MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
26
27#include "llvm/ADT/ArrayRef.h"
28#include "llvm/ADT/PostOrderIterator.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/STLForwardCompat.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/ADT/bit.h"
36#include "llvm/Support/InterleavedRange.h"
37#include <cstddef>
38#include <iterator>
39#include <optional>
40#include <variant>
41
42#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
46
47using namespace mlir;
48using namespace mlir::omp;
49
52 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
53}
54
57 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
58}
59
62 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
63}
64
65namespace {
66struct MemRefPointerLikeModel
67 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
68 MemRefType> {
69 Type getElementType(Type pointer) const {
70 return llvm::cast<MemRefType>(pointer).getElementType();
71 }
72};
73
74struct LLVMPointerPointerLikeModel
75 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
77 Type getElementType(Type pointer) const { return Type(); }
78};
79} // namespace
80
81/// Generate a name of a canonical loop nest of the format
82/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
83/// argument index of an operation that has multiple regions, if the operation
84/// has multiple regions.
85/// `_s<idx>` identifies the position of an operation within a region, where
86/// only operations that may potentially contain loops ("container operations"
87/// i.e. have region arguments) are counted. Again, it is omitted if there is
88/// only one such operation in a region. If there are canonical loops nested
89/// inside each other, also may also use the format `_d<num>` where <num> is the
90/// nesting depth of the loop.
91///
92/// The generated name is a best-effort to make canonical loop unique within an
93/// SSA namespace. This also means that regions with IsolatedFromAbove property
94/// do not consider any parents or siblings.
95static std::string generateLoopNestingName(StringRef prefix,
96 CanonicalLoopOp op) {
97 struct Component {
98 /// If true, this component describes a region operand of an operation (the
99 /// operand's owner) If false, this component describes an operation located
100 /// in a parent region
101 bool isRegionArgOfOp;
102 bool skip = false;
103 bool isUnique = false;
104
105 size_t idx;
106 Operation *op;
107 Region *parentRegion;
108 size_t loopDepth;
109
110 Operation *&getOwnerOp() {
111 assert(isRegionArgOfOp && "Must describe a region operand");
112 return op;
113 }
114 size_t &getArgIdx() {
115 assert(isRegionArgOfOp && "Must describe a region operand");
116 return idx;
117 }
118
119 Operation *&getContainerOp() {
120 assert(!isRegionArgOfOp && "Must describe a operation of a region");
121 return op;
122 }
123 size_t &getOpPos() {
124 assert(!isRegionArgOfOp && "Must describe a operation of a region");
125 return idx;
126 }
127 bool isLoopOp() const {
128 assert(!isRegionArgOfOp && "Must describe a operation of a region");
129 return isa<CanonicalLoopOp>(op);
130 }
131 Region *&getParentRegion() {
132 assert(!isRegionArgOfOp && "Must describe a operation of a region");
133 return parentRegion;
134 }
135 size_t &getLoopDepth() {
136 assert(!isRegionArgOfOp && "Must describe a operation of a region");
137 return loopDepth;
138 }
139
140 void skipIf(bool v = true) { skip = skip || v; }
141 };
142
143 // List of ancestors, from inner to outer.
144 // Alternates between
145 // * region argument of an operation
146 // * operation within a region
147 SmallVector<Component> components;
148
149 // Gather a list of parent regions and operations, and the position within
150 // their parent
151 Operation *o = op.getOperation();
152 while (o) {
153 // Operation within a region
154 Region *r = o->getParentRegion();
155 if (!r)
156 break;
157
158 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
159 size_t idx = 0;
160 bool found = false;
161 size_t sequentialIdx = -1;
162 bool isOnlyContainerOp = true;
163 for (Block *b : traversal) {
164 for (Operation &op : *b) {
165 if (&op == o && !found) {
166 sequentialIdx = idx;
167 found = true;
168 }
169 if (op.getNumRegions()) {
170 idx += 1;
171 if (idx > 1)
172 isOnlyContainerOp = false;
173 }
174 if (found && !isOnlyContainerOp)
175 break;
176 }
177 }
178
179 Component &containerOpInRegion = components.emplace_back();
180 containerOpInRegion.isRegionArgOfOp = false;
181 containerOpInRegion.isUnique = isOnlyContainerOp;
182 containerOpInRegion.getContainerOp() = o;
183 containerOpInRegion.getOpPos() = sequentialIdx;
184 containerOpInRegion.getParentRegion() = r;
185
186 Operation *parent = r->getParentOp();
187
188 // Region argument of an operation
189 Component &regionArgOfOperation = components.emplace_back();
190 regionArgOfOperation.isRegionArgOfOp = true;
191 regionArgOfOperation.isUnique = true;
192 regionArgOfOperation.getArgIdx() = 0;
193 regionArgOfOperation.getOwnerOp() = parent;
194
195 // The IsolatedFromAbove trait of the parent operation implies that each
196 // individual region argument has its own separate namespace, so no
197 // ambiguity.
198 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
199 break;
200
201 // Component only needed if operation has multiple region operands. Region
202 // arguments may be optional, but we currently do not consider this.
203 if (parent->getRegions().size() > 1) {
204 auto getRegionIndex = [](Operation *o, Region *r) {
205 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
206 if (&region == r)
207 return idx;
208 }
209 llvm_unreachable("Region not child of its parent operation");
210 };
211 regionArgOfOperation.isUnique = false;
212 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
213 }
214
215 // next parent
216 o = parent;
217 }
218
219 // Determine whether a region-argument component is not needed
220 for (Component &c : components)
221 c.skipIf(c.isRegionArgOfOp && c.isUnique);
222
223 // Find runs of nested loops and determine each loop's depth in the loop nest
224 size_t numSurroundingLoops = 0;
225 for (Component &c : llvm::reverse(components)) {
226 if (c.skip)
227 continue;
228
229 // non-skipped multi-argument operands interrupt the loop nest
230 if (c.isRegionArgOfOp) {
231 numSurroundingLoops = 0;
232 continue;
233 }
234
235 // Multiple loops in a region means each of them is the outermost loop of a
236 // new loop nest
237 if (!c.isUnique)
238 numSurroundingLoops = 0;
239
240 c.getLoopDepth() = numSurroundingLoops;
241
242 // Next loop is surrounded by one more loop
243 if (isa<CanonicalLoopOp>(c.getContainerOp()))
244 numSurroundingLoops += 1;
245 }
246
247 // In loop nests, skip all but the innermost loop that contains the depth
248 // number
249 bool isLoopNest = false;
250 for (Component &c : components) {
251 if (c.skip || c.isRegionArgOfOp)
252 continue;
253
254 if (!isLoopNest && c.getLoopDepth() >= 1) {
255 // Innermost loop of a loop nest of at least two loops
256 isLoopNest = true;
257 } else if (isLoopNest) {
258 // Non-innermost loop of a loop nest
259 c.skipIf(c.isUnique);
260
261 // If there is no surrounding loop left, this must have been the outermost
262 // loop; leave loop-nest mode for the next iteration
263 if (c.getLoopDepth() == 0)
264 isLoopNest = false;
265 }
266 }
267
268 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
269 // we mark them as skipped only after computing loop nests)
270 for (Component &c : components)
271 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272 !isa<CanonicalLoopOp>(c.getContainerOp()));
273
274 // Components can be skipped if they are already disambiguated by their parent
275 // (or does not have a parent)
276 bool newRegion = true;
277 for (Component &c : llvm::reverse(components)) {
278 c.skipIf(newRegion && c.isUnique);
279
280 // non-skipped components disambiguate unique children
281 if (!c.skip)
282 newRegion = true;
283
284 // ...except canonical loops that need a suffix for each nest
285 if (!c.isRegionArgOfOp && c.getContainerOp())
286 newRegion = false;
287 }
288
289 // Compile the nesting name string
290 SmallString<64> Name{prefix};
291 llvm::raw_svector_ostream NameOS(Name);
292 for (auto &c : llvm::reverse(components)) {
293 if (c.skip)
294 continue;
295
296 if (c.isRegionArgOfOp)
297 NameOS << "_r" << c.getArgIdx();
298 else if (c.getLoopDepth() >= 1)
299 NameOS << "_d" << c.getLoopDepth();
300 else
301 NameOS << "_s" << c.getOpPos();
302 }
303
304 return NameOS.str().str();
305}
306
/// Set up the OpenMP dialect: register its generated operations, attributes
/// and types, declare the promised ConvertToLLVMPatternInterface, and attach
/// external interface models to types and operations of other dialects.
307void OpenMPDialect::initialize() {
308 addOperations<
309#define GET_OP_LIST
310#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
311 >();
312 addAttributes<
313#define GET_ATTRDEF_LIST
314#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
315 >();
316 addTypes<
317#define GET_TYPEDEF_LIST
318#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
319 >();
320
321 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322
323 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
324 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
325 *getContext());
326
327 // Attach the default offload module interface to the module op so that
328 // offload functionality can be accessed through it.
329 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
330 *getContext());
331
332 // Attach default declare target interfaces to operations which can be marked
333 // as declare target (global operations and functions/subroutines in dialects
334 // that Fortran, or other languages that lower to MLIR, translates to).
335 mlir::LLVM::GlobalOp::attachInterface<
337 *getContext());
338 mlir::LLVM::LLVMFuncOp::attachInterface<
340 *getContext());
341 mlir::func::FuncOp::attachInterface<
343}
344
345//===----------------------------------------------------------------------===//
346// Parser and printer for Allocate Clause
347//===----------------------------------------------------------------------===//
348
349/// Parse an allocate clause: a comma-separated list of allocator/variable
350/// pairs in which every operand is written together with its type.
351/// allocate-operand-list ::= allocate-operand |
352/// allocate-operand `,` allocate-operand-list
353/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
354/// ssa-id-and-type ::= ssa-id `:` type
///
/// The operand before `->` is appended to `allocatorVars`/`allocatorTypes`,
/// the operand after it to `allocateVars`/`allocateTypes`.
355static ParseResult parseAllocateAndAllocator(
356 OpAsmParser &parser,
358 SmallVectorImpl<Type> &allocateTypes,
360 SmallVectorImpl<Type> &allocatorTypes) {
361
362 return parser.parseCommaSeparatedList([&]() {
364 Type type;
// Left-hand side of the arrow: the allocator and its type.
365 if (parser.parseOperand(operand) || parser.parseColonType(type))
366 return failure();
367 allocatorVars.push_back(operand);
368 allocatorTypes.push_back(type);
369 if (parser.parseArrow())
370 return failure();
// Right-hand side of the arrow: the allocated variable and its type.
371 if (parser.parseOperand(operand) || parser.parseColonType(type))
372 return failure();
373
374 allocateVars.push_back(operand);
375 allocateTypes.push_back(type);
376 return success();
377 });
378}
379
380/// Print allocate clause
382 OperandRange allocateVars,
383 TypeRange allocateTypes,
384 OperandRange allocatorVars,
385 TypeRange allocatorTypes) {
386 for (unsigned i = 0; i < allocateVars.size(); ++i) {
387 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
388 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
389 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
390 }
391}
392
393//===----------------------------------------------------------------------===//
394// Parser and printer for a clause attribute (StringEnumAttr)
395//===----------------------------------------------------------------------===//
396
397template <typename ClauseAttr>
398static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
399 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
400 StringRef enumStr;
401 SMLoc loc = parser.getCurrentLocation();
402 if (parser.parseKeyword(&enumStr))
403 return failure();
404 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
405 attr = ClauseAttr::get(parser.getContext(), *enumValue);
406 return success();
407 }
408 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
409}
410
411template <typename ClauseAttr>
412static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
413 p << stringifyEnum(attr.getValue());
414}
415
416//===----------------------------------------------------------------------===//
417// Parser and printer for Linear Clause
418//===----------------------------------------------------------------------===//
419
420/// linear ::= `linear` `(` linear-list `)`
421/// linear-list ::= linear-val | linear-val `,` linear-list
422/// linear-val ::= ssa-id-and-type `=` ssa-id-and-type
423/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
424/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
///
/// For each list element the variable and its step are appended to the
/// output vectors together with their types. `linearModifiers` receives one
/// entry per element: a LinearModifierAttr when a modifier keyword was
/// written, otherwise a UnitAttr placeholder.
426static ParseResult parseLinearClause(
427 OpAsmParser &parser,
429 SmallVectorImpl<Type> &linearTypes,
431 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
432 SmallVector<Attribute> modifiers;
433 auto result = parser.parseCommaSeparatedList([&]() {
435 Type type, stepType;
437
// Optional modifier keyword in front of the element.
438 std::optional<omp::LinearModifier> linearModifier;
439 if (succeeded(parser.parseOptionalKeyword("val"))) {
440 linearModifier = omp::LinearModifier::val;
441 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
442 linearModifier = omp::LinearModifier::ref;
443 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
444 linearModifier = omp::LinearModifier::uval;
445 }
446
// A modifier keyword must be followed by a parenthesized element.
447 bool hasLinearModifierParens = linearModifier.has_value();
448 if (hasLinearModifierParens && parser.parseLParen())
449 return failure();
450
451 if (parser.parseOperand(var) || parser.parseColonType(type) ||
452 parser.parseEqual() || parser.parseOperand(stepVar) ||
453 parser.parseColonType(stepType))
454 return failure();
455
456 if (hasLinearModifierParens && parser.parseRParen())
457 return failure();
458
459 linearVars.push_back(var);
460 linearTypes.push_back(type);
461 linearStepVars.push_back(stepVar);
462 linearStepTypes.push_back(stepType);
463 if (linearModifier) {
464 modifiers.push_back(
465 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
466 } else {
// Keeps the modifier array index-aligned with the variables.
467 modifiers.push_back(UnitAttr::get(parser.getContext()));
468 }
469 return success();
470 });
471 if (failed(result))
472 return failure();
473 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
474 return success();
475}
476
477/// Print Linear Clause
479 ValueRange linearVars, TypeRange linearTypes,
480 ValueRange linearStepVars, TypeRange stepVarTypes,
481 ArrayAttr linearModifiers) {
482 size_t linearVarsSize = linearVars.size();
483 for (unsigned i = 0; i < linearVarsSize; ++i) {
484 if (i != 0)
485 p << ", ";
486 // Print modifier keyword wrapper if present.
487 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
488 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
489 if (mod) {
490 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
491 }
492 p << linearVars[i] << " : " << linearTypes[i];
493 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
494 if (mod)
495 p << ")";
496 }
497}
498
499//===----------------------------------------------------------------------===//
500// Verifier for Linear modifier
501//===----------------------------------------------------------------------===//
502
503/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
504/// uval only on a declare simd directive."
505/// Also verifies that modifier count matches variable count.
506static LogicalResult
507verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
508 OperandRange linearVars, bool isDeclareSimd = false) {
509 if (!linearModifiers)
510 return success();
511 if (linearModifiers->size() != linearVars.size())
512 return op->emitOpError()
513 << "expected as many linear modifiers as linear variables";
514 if (!isDeclareSimd) {
515 for (Attribute attr : *linearModifiers) {
516 if (!attr)
517 continue;
518 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
519 if (!modAttr)
520 continue;
521 omp::LinearModifier mod = modAttr.getValue();
522 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
523 return op->emitOpError()
524 << "linear modifier '" << omp::stringifyLinearModifier(mod)
525 << "' may only be specified on a declare simd directive";
526 }
527 }
528 return success();
529}
530
531//===----------------------------------------------------------------------===//
532// Verifier for Nontemporal Clause
533//===----------------------------------------------------------------------===//
534
535static LogicalResult verifyNontemporalClause(Operation *op,
536 OperandRange nontemporalVars) {
537
538 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
539 DenseSet<Value> nontemporalItems;
540 for (const auto &it : nontemporalVars)
541 if (!nontemporalItems.insert(it).second)
542 return op->emitOpError() << "nontemporal variable used more than once";
543
544 return success();
545}
546
547//===----------------------------------------------------------------------===//
548// Parser, verifier and printer for Aligned Clause
549//===----------------------------------------------------------------------===//
550static LogicalResult verifyAlignedClause(Operation *op,
551 std::optional<ArrayAttr> alignments,
552 OperandRange alignedVars) {
553 // Check if number of alignment values equals to number of aligned variables
554 if (!alignedVars.empty()) {
555 if (!alignments || alignments->size() != alignedVars.size())
556 return op->emitOpError()
557 << "expected as many alignment values as aligned variables";
558 } else {
559 if (alignments)
560 return op->emitOpError() << "unexpected alignment values attribute";
561 return success();
562 }
563
564 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
565 DenseSet<Value> alignedItems;
566 for (auto it : alignedVars)
567 if (!alignedItems.insert(it).second)
568 return op->emitOpError() << "aligned variable used more than once";
569
570 if (!alignments)
571 return success();
572
573 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
574 for (unsigned i = 0; i < (*alignments).size(); ++i) {
575 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
576 if (intAttr.getValue().sle(0))
577 return op->emitOpError() << "alignment should be greater than 0";
578 } else {
579 return op->emitOpError() << "expected integer alignment";
580 }
581 }
582
583 return success();
584}
585
586/// aligned ::= `aligned` `(` aligned-list `)`
587/// aligned-list := aligned-val | aligned-val aligned-list
588/// aligned-val := ssa-id-and-type `->` alignment
589static ParseResult
592 SmallVectorImpl<Type> &alignedTypes,
593 ArrayAttr &alignmentsAttr) {
594 SmallVector<Attribute> alignmentVec;
595 if (failed(parser.parseCommaSeparatedList([&]() {
596 if (parser.parseOperand(alignedVars.emplace_back()) ||
597 parser.parseColonType(alignedTypes.emplace_back()) ||
598 parser.parseArrow() ||
599 parser.parseAttribute(alignmentVec.emplace_back())) {
600 return failure();
601 }
602 return success();
603 })))
604 return failure();
605 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
606 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
607 return success();
608}
609
610/// Print Aligned Clause
612 ValueRange alignedVars, TypeRange alignedTypes,
613 std::optional<ArrayAttr> alignments) {
614 for (unsigned i = 0; i < alignedVars.size(); ++i) {
615 if (i != 0)
616 p << ", ";
617 p << alignedVars[i] << " : " << alignedVars[i].getType();
618 p << " -> " << (*alignments)[i];
619 }
620}
621
622//===----------------------------------------------------------------------===//
623// Parser, printer and verifier for Schedule Clause
624//===----------------------------------------------------------------------===//
625
626static ParseResult
628 SmallVectorImpl<SmallString<12>> &modifiers) {
629 if (modifiers.size() > 2)
630 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
631 for (const auto &mod : modifiers) {
632 // Translate the string. If it has no value, then it was not a valid
633 // modifier!
634 auto symbol = symbolizeScheduleModifier(mod);
635 if (!symbol)
636 return parser.emitError(parser.getNameLoc())
637 << " unknown modifier type: " << mod;
638 }
639
640 // If we have one modifier that is "simd", then stick a "none" modifier in
641 // index 0.
642 if (modifiers.size() == 1) {
643 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
644 modifiers.push_back(modifiers[0]);
645 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
646 }
647 } else if (modifiers.size() == 2) {
648 // If there are two modifier:
649 // First modifier should not be simd, second one should be simd
650 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
651 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
652 return parser.emitError(parser.getNameLoc())
653 << " incorrect modifier order";
654 }
655 return success();
656}
657
658/// schedule ::= `schedule` `(` sched-list `)`
659/// sched-list ::= sched-val | sched-val sched-list |
660/// sched-val `,` sched-modifier
661/// sched-val ::= sched-with-chunk | sched-wo-chunk
662/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
663/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
664/// sched-wo-chunk ::= `auto` | `runtime`
665/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
666/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// Sets `scheduleAttr` from the schedule kind keyword. `chunkSize` and
/// `chunkType` are filled only when a `= operand : type` chunk follows a
/// static, dynamic or guided kind; otherwise `chunkSize` is reset. The first
/// modifier is stored in `scheduleMod`; a second modifier sets `scheduleSimd`.
667static ParseResult
668parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
669 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
670 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
671 Type &chunkType) {
672 StringRef keyword;
673 if (parser.parseKeyword(&keyword))
674 return failure();
675 std::optional<mlir::omp::ClauseScheduleKind> schedule =
676 symbolizeClauseScheduleKind(keyword);
677 if (!schedule)
678 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
679
680 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
681 switch (*schedule) {
682 case ClauseScheduleKind::Static:
683 case ClauseScheduleKind::Dynamic:
684 case ClauseScheduleKind::Guided:
// These kinds accept an optional chunk size operand.
685 if (succeeded(parser.parseOptionalEqual())) {
686 chunkSize = OpAsmParser::UnresolvedOperand{};
687 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
688 return failure();
689 } else {
690 chunkSize = std::nullopt;
691 }
692 break;
693 case ClauseScheduleKind::Auto:
694 case ClauseScheduleKind::Runtime:
695 case ClauseScheduleKind::Distribute:
696 chunkSize = std::nullopt;
697 }
698
699 // If there is a comma, we have one or more modifiers.
701 while (succeeded(parser.parseOptionalComma())) {
702 StringRef mod;
703 if (parser.parseKeyword(&mod))
704 return failure();
705 modifiers.push_back(mod);
706 }
707
708 if (verifyScheduleModifiers(parser, modifiers))
709 return failure();
710
711 if (!modifiers.empty()) {
712 SMLoc loc = parser.getCurrentLocation();
713 if (std::optional<ScheduleModifier> mod =
714 symbolizeScheduleModifier(modifiers[0])) {
715 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
716 } else {
717 return parser.emitError(loc, "invalid schedule modifier");
718 }
719 // Only the SIMD modifier is allowed in the second position.
720 if (modifiers.size() > 1) {
721 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
722 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
723 }
724 }
725
726 return success();
727}
728
729/// Print schedule clause
731 ClauseScheduleKindAttr scheduleKind,
732 ScheduleModifierAttr scheduleMod,
733 UnitAttr scheduleSimd, Value scheduleChunk,
734 Type scheduleChunkType) {
735 p << stringifyClauseScheduleKind(scheduleKind.getValue());
736 if (scheduleChunk)
737 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
738 if (scheduleMod)
739 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
740 if (scheduleSimd)
741 p << ", simd";
742}
743
744//===----------------------------------------------------------------------===//
745// Parser and printer for Order Clause
746//===----------------------------------------------------------------------===//
747
748// order ::= `order` `(` [order-modifier ':'] concurrent `)`
749// order-modifier ::= reproducible | unconstrained
750static ParseResult parseOrderClause(OpAsmParser &parser,
751 ClauseOrderKindAttr &order,
752 OrderModifierAttr &orderMod) {
753 StringRef enumStr;
754 SMLoc loc = parser.getCurrentLocation();
755 if (parser.parseKeyword(&enumStr))
756 return failure();
757 if (std::optional<OrderModifier> enumValue =
758 symbolizeOrderModifier(enumStr)) {
759 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
760 if (parser.parseOptionalColon())
761 return failure();
762 loc = parser.getCurrentLocation();
763 if (parser.parseKeyword(&enumStr))
764 return failure();
765 }
766 if (std::optional<ClauseOrderKind> enumValue =
767 symbolizeClauseOrderKind(enumStr)) {
768 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
769 return success();
770 }
771 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
772}
773
775 ClauseOrderKindAttr order,
776 OrderModifierAttr orderMod) {
777 if (orderMod)
778 p << stringifyOrderModifier(orderMod.getValue()) << ":";
779 if (order)
780 p << stringifyClauseOrderKind(order.getValue());
781}
782
/// Parse a clause of the form `[modifier `,`] operand `:` type`.
///
/// If a keyword is present it is converted with `symbolizeClause`; a known
/// keyword sets `prescriptiveness` and must be followed by a comma, an
/// unknown one is reported as an invalid `clauseName` modifier. The operand
/// is then parsed into `operand` and its type into `operandType`.
/// `clauseName` is only used in diagnostics.
783template <typename ClauseTypeAttr, typename ClauseType>
784static ParseResult
785parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
786 std::optional<OpAsmParser::UnresolvedOperand> &operand,
787 Type &operandType,
788 std::optional<ClauseType> (*symbolizeClause)(StringRef),
789 StringRef clauseName) {
790 StringRef enumStr;
791 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
792 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
793 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
794 if (parser.parseComma())
795 return failure();
796 } else {
797 return parser.emitError(parser.getCurrentLocation())
798 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
799 ;
800 }
801 }
802
804 if (succeeded(parser.parseOperand(var))) {
805 operand = var;
806 } else {
807 return parser.emitError(parser.getCurrentLocation())
808 << "expected " << clauseName << " operand";
809 }
810
// The operand's type is mandatory once the operand has been parsed.
811 if (operand.has_value()) {
812 if (parser.parseColonType(operandType))
813 return failure();
814 }
815
816 return success();
817}
818
819template <typename ClauseTypeAttr, typename ClauseType>
820static void
822 ClauseTypeAttr prescriptiveness, Value operand,
823 mlir::Type operandType,
824 StringRef (*stringifyClauseType)(ClauseType)) {
825
826 if (prescriptiveness)
827 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
828
829 if (operand)
830 p << operand << ": " << operandType;
831}
832
833//===----------------------------------------------------------------------===//
834// Parser and printer for grainsize Clause
835//===----------------------------------------------------------------------===//
836
837// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
// NOTE(review): the generic granularity parser and printer in this file
// separate the modifier from the operand with `,`, not `:` - confirm which
// spelling the grammar above should show.
838static ParseResult
839parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
840 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
841 Type &grainsizeType) {
843 parser, grainsizeMod, grainsize, grainsizeType,
844 &symbolizeClauseGrainsizeType, "grainsize");
845}
846
848 ClauseGrainsizeTypeAttr grainsizeMod,
849 Value grainsize, mlir::Type grainsizeType) {
851 p, op, grainsizeMod, grainsize, grainsizeType,
852 &stringifyClauseGrainsizeType);
853}
854
855//===----------------------------------------------------------------------===//
856// Parser and printer for num_tasks Clause
857//===----------------------------------------------------------------------===//
858
859// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
// NOTE(review): the generic granularity parser and printer in this file
// separate the modifier from the operand with `,`, not `:` - confirm which
// spelling the grammar above should show.
860static ParseResult
861parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
862 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
863 Type &numTasksType) {
865 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
866 "num_tasks");
867}
868
870 ClauseNumTasksTypeAttr numTasksMod,
871 Value numTasks, mlir::Type numTasksType) {
873 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
874}
875
876//===----------------------------------------------------------------------===//
877// Parsers for operations including clauses that define entry block arguments.
878//===----------------------------------------------------------------------===//
879
880namespace {
881struct MapParseArgs {
882 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
883 SmallVectorImpl<Type> &types;
884 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
885 SmallVectorImpl<Type> &types)
886 : vars(vars), types(types) {}
887};
888struct PrivateParseArgs {
889 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
890 llvm::SmallVectorImpl<Type> &types;
891 ArrayAttr &syms;
892 UnitAttr &needsBarrier;
893 DenseI64ArrayAttr *mapIndices;
894 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
895 SmallVectorImpl<Type> &types, ArrayAttr &syms,
896 UnitAttr &needsBarrier,
897 DenseI64ArrayAttr *mapIndices = nullptr)
898 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
899 mapIndices(mapIndices) {}
900};
901
902struct ReductionParseArgs {
903 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
904 SmallVectorImpl<Type> &types;
905 DenseBoolArrayAttr &byref;
906 ArrayAttr &syms;
907 ReductionModifierAttr *modifier;
908 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
909 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
910 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
911 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
912};
913
914struct AllRegionParseArgs {
915 std::optional<MapParseArgs> hasDeviceAddrArgs;
916 std::optional<MapParseArgs> hostEvalArgs;
917 std::optional<ReductionParseArgs> inReductionArgs;
918 std::optional<MapParseArgs> mapArgs;
919 std::optional<PrivateParseArgs> privateArgs;
920 std::optional<ReductionParseArgs> reductionArgs;
921 std::optional<ReductionParseArgs> taskReductionArgs;
922 std::optional<MapParseArgs> useDeviceAddrArgs;
923 std::optional<MapParseArgs> useDevicePtrArgs;
924};
925} // namespace
926
927static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
928 return "private_barrier";
929}
930
931static ParseResult parseClauseWithRegionArgs(
932 OpAsmParser &parser,
936 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
937 DenseBoolArrayAttr *byref = nullptr,
938 ReductionModifierAttr *modifier = nullptr,
939 UnitAttr *needsBarrier = nullptr) {
941 SmallVector<int64_t> mapIndicesVec;
942 SmallVector<bool> isByRefVec;
943 unsigned regionArgOffset = regionPrivateArgs.size();
944
945 if (parser.parseLParen())
946 return failure();
947
948 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
949 StringRef enumStr;
950 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
951 parser.parseComma())
952 return failure();
953 std::optional<ReductionModifier> enumValue =
954 symbolizeReductionModifier(enumStr);
955 if (!enumValue.has_value())
956 return failure();
957 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
958 if (!*modifier)
959 return failure();
960 }
961
962 if (parser.parseCommaSeparatedList([&]() {
963 if (byref)
964 isByRefVec.push_back(
965 parser.parseOptionalKeyword("byref").succeeded());
966
967 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
968 return failure();
969
970 if (parser.parseOperand(operands.emplace_back()) ||
971 parser.parseArrow() ||
972 parser.parseArgument(regionPrivateArgs.emplace_back()))
973 return failure();
974
975 if (mapIndices) {
976 if (parser.parseOptionalLSquare().succeeded()) {
977 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
978 parser.parseInteger(mapIndicesVec.emplace_back()) ||
979 parser.parseRSquare())
980 return failure();
981 } else {
982 mapIndicesVec.push_back(-1);
983 }
984 }
985
986 return success();
987 }))
988 return failure();
989
990 if (parser.parseColon())
991 return failure();
992
993 if (parser.parseCommaSeparatedList([&]() {
994 if (parser.parseType(types.emplace_back()))
995 return failure();
996
997 return success();
998 }))
999 return failure();
1000
1001 if (operands.size() != types.size())
1002 return failure();
1003
1004 if (parser.parseRParen())
1005 return failure();
1006
1007 if (needsBarrier) {
1009 .succeeded())
1010 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
1011 }
1012
1013 auto *argsBegin = regionPrivateArgs.begin();
1014 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
1015 argsBegin + regionArgOffset + types.size());
1016 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1017 prv.type = type;
1018 }
1019
1020 if (symbols) {
1021 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
1022 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
1023 }
1024
1025 if (!mapIndicesVec.empty())
1026 *mapIndices =
1027 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
1028
1029 if (byref)
1030 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
1031
1032 return success();
1033}
1034
1035static ParseResult parseBlockArgClause(
1036 OpAsmParser &parser,
1038 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1039 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1040 if (!mapArgs)
1041 return failure();
1042
1043 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
1044 entryBlockArgs)))
1045 return failure();
1046 }
1047 return success();
1048}
1049
1050static ParseResult parseBlockArgClause(
1051 OpAsmParser &parser,
1053 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1054 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1055 if (!privateArgs)
1056 return failure();
1057
1058 if (failed(parseClauseWithRegionArgs(
1059 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1060 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1061 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
1062 return failure();
1063 }
1064 return success();
1065}
1066
1067static ParseResult parseBlockArgClause(
1068 OpAsmParser &parser,
1070 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1071 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1072 if (!reductionArgs)
1073 return failure();
1074 if (failed(parseClauseWithRegionArgs(
1075 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1076 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1077 reductionArgs->modifier)))
1078 return failure();
1079 }
1080 return success();
1081}
1082
1083static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1084 AllRegionParseArgs args) {
1086
1087 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1088 args.hasDeviceAddrArgs)))
1089 return parser.emitError(parser.getCurrentLocation())
1090 << "invalid `has_device_addr` format";
1091
1092 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1093 args.hostEvalArgs)))
1094 return parser.emitError(parser.getCurrentLocation())
1095 << "invalid `host_eval` format";
1096
1097 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1098 args.inReductionArgs)))
1099 return parser.emitError(parser.getCurrentLocation())
1100 << "invalid `in_reduction` format";
1101
1102 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1103 args.mapArgs)))
1104 return parser.emitError(parser.getCurrentLocation())
1105 << "invalid `map_entries` format";
1106
1107 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1108 args.privateArgs)))
1109 return parser.emitError(parser.getCurrentLocation())
1110 << "invalid `private` format";
1111
1112 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1113 args.reductionArgs)))
1114 return parser.emitError(parser.getCurrentLocation())
1115 << "invalid `reduction` format";
1116
1117 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1118 args.taskReductionArgs)))
1119 return parser.emitError(parser.getCurrentLocation())
1120 << "invalid `task_reduction` format";
1121
1122 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1123 args.useDeviceAddrArgs)))
1124 return parser.emitError(parser.getCurrentLocation())
1125 << "invalid `use_device_addr` format";
1126
1127 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1128 args.useDevicePtrArgs)))
1129 return parser.emitError(parser.getCurrentLocation())
1130 << "invalid `use_device_addr` format";
1131
1132 return parser.parseRegion(region, entryBlockArgs);
1133}
1134
1135// These parseXyz functions correspond to the custom<Xyz> definitions
1136// in the .td file(s).
1137static ParseResult parseTargetOpRegion(
1138 OpAsmParser &parser, Region &region,
1140 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1142 SmallVectorImpl<Type> &hostEvalTypes,
1144 SmallVectorImpl<Type> &inReductionTypes,
1145 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1147 SmallVectorImpl<Type> &mapTypes,
1149 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1150 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1151 AllRegionParseArgs args;
1152 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1153 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1154 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1155 inReductionByref, inReductionSyms);
1156 args.mapArgs.emplace(mapVars, mapTypes);
1157 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1158 privateNeedsBarrier, &privateMaps);
1159 return parseBlockArgRegion(parser, region, args);
1160}
1161
1163 OpAsmParser &parser, Region &region,
1165 SmallVectorImpl<Type> &inReductionTypes,
1166 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1168 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1169 UnitAttr &privateNeedsBarrier) {
1170 AllRegionParseArgs args;
1171 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1172 inReductionByref, inReductionSyms);
1173 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1174 privateNeedsBarrier);
1175 return parseBlockArgRegion(parser, region, args);
1176}
1177
1179 OpAsmParser &parser, Region &region,
1181 SmallVectorImpl<Type> &inReductionTypes,
1182 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1184 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1185 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1187 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1188 ArrayAttr &reductionSyms) {
1189 AllRegionParseArgs args;
1190 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1191 inReductionByref, inReductionSyms);
1192 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1193 privateNeedsBarrier);
1194 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1195 reductionSyms, &reductionMod);
1196 return parseBlockArgRegion(parser, region, args);
1197}
1198
1199static ParseResult parsePrivateRegion(
1200 OpAsmParser &parser, Region &region,
1202 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1203 UnitAttr &privateNeedsBarrier) {
1204 AllRegionParseArgs args;
1205 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1206 privateNeedsBarrier);
1207 return parseBlockArgRegion(parser, region, args);
1208}
1209
1211 OpAsmParser &parser, Region &region,
1213 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1214 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1216 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1217 ArrayAttr &reductionSyms) {
1218 AllRegionParseArgs args;
1219 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1220 privateNeedsBarrier);
1221 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1222 reductionSyms, &reductionMod);
1223 return parseBlockArgRegion(parser, region, args);
1224}
1225
1226static ParseResult parseTaskReductionRegion(
1227 OpAsmParser &parser, Region &region,
1229 SmallVectorImpl<Type> &taskReductionTypes,
1230 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1231 AllRegionParseArgs args;
1232 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1233 taskReductionByref, taskReductionSyms);
1234 return parseBlockArgRegion(parser, region, args);
1235}
1236
1238 OpAsmParser &parser, Region &region,
1240 SmallVectorImpl<Type> &useDeviceAddrTypes,
1242 SmallVectorImpl<Type> &useDevicePtrTypes) {
1243 AllRegionParseArgs args;
1244 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1245 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1246 return parseBlockArgRegion(parser, region, args);
1247}
1248
1249//===----------------------------------------------------------------------===//
1250// Printers for operations including clauses that define entry block arguments.
1251//===----------------------------------------------------------------------===//
1252
1253namespace {
1254struct MapPrintArgs {
1255 ValueRange vars;
1256 TypeRange types;
1257 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1258};
1259struct PrivatePrintArgs {
1260 ValueRange vars;
1261 TypeRange types;
1262 ArrayAttr syms;
1263 UnitAttr needsBarrier;
1264 DenseI64ArrayAttr mapIndices;
1265 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1266 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1267 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1268 mapIndices(mapIndices) {}
1269};
1270struct ReductionPrintArgs {
1271 ValueRange vars;
1272 TypeRange types;
1273 DenseBoolArrayAttr byref;
1274 ArrayAttr syms;
1275 ReductionModifierAttr modifier;
1276 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1277 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1278 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1279};
1280struct AllRegionPrintArgs {
1281 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1282 std::optional<MapPrintArgs> hostEvalArgs;
1283 std::optional<ReductionPrintArgs> inReductionArgs;
1284 std::optional<MapPrintArgs> mapArgs;
1285 std::optional<PrivatePrintArgs> privateArgs;
1286 std::optional<ReductionPrintArgs> reductionArgs;
1287 std::optional<ReductionPrintArgs> taskReductionArgs;
1288 std::optional<MapPrintArgs> useDeviceAddrArgs;
1289 std::optional<MapPrintArgs> useDevicePtrArgs;
1290};
1291} // namespace
1292
1294 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1295 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1296 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1297 DenseBoolArrayAttr byref = nullptr,
1298 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1299 if (argsSubrange.empty())
1300 return;
1301
1302 p << clauseName << "(";
1303
1304 if (modifier)
1305 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1306
1307 if (!symbols) {
1308 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1309 symbols = ArrayAttr::get(ctx, values);
1310 }
1311
1312 if (!mapIndices) {
1313 llvm::SmallVector<int64_t> values(operands.size(), -1);
1314 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1315 }
1316
1317 if (!byref) {
1318 mlir::SmallVector<bool> values(operands.size(), false);
1319 byref = DenseBoolArrayAttr::get(ctx, values);
1320 }
1321
1322 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1323 mapIndices.asArrayRef(),
1324 byref.asArrayRef()),
1325 p, [&p](auto t) {
1326 auto [op, arg, sym, map, isByRef] = t;
1327 if (isByRef)
1328 p << "byref ";
1329 if (sym)
1330 p << sym << " ";
1331
1332 p << op << " -> " << arg;
1333
1334 if (map != -1)
1335 p << " [map_idx=" << map << "]";
1336 });
1337 p << " : ";
1338 llvm::interleaveComma(types, p);
1339 p << ") ";
1340
1341 if (needsBarrier)
1342 p << getPrivateNeedsBarrierSpelling() << " ";
1343}
1344
1346 StringRef clauseName, ValueRange argsSubrange,
1347 std::optional<MapPrintArgs> mapArgs) {
1348 if (mapArgs)
1349 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1350 mapArgs->types);
1351}
1352
1354 StringRef clauseName, ValueRange argsSubrange,
1355 std::optional<PrivatePrintArgs> privateArgs) {
1356 if (privateArgs)
1358 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1359 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1360 /*modifier=*/nullptr, privateArgs->needsBarrier);
1361}
1362
1363static void
1364printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1365 ValueRange argsSubrange,
1366 std::optional<ReductionPrintArgs> reductionArgs) {
1367 if (reductionArgs)
1368 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1369 reductionArgs->vars, reductionArgs->types,
1370 reductionArgs->syms, /*mapIndices=*/nullptr,
1371 reductionArgs->byref, reductionArgs->modifier);
1372}
1373
1375 const AllRegionPrintArgs &args) {
1376 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1377 MLIRContext *ctx = op->getContext();
1378
1379 printBlockArgClause(p, ctx, "has_device_addr",
1380 iface.getHasDeviceAddrBlockArgs(),
1381 args.hasDeviceAddrArgs);
1382 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1383 args.hostEvalArgs);
1384 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1385 args.inReductionArgs);
1386 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1387 args.mapArgs);
1388 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1389 args.privateArgs);
1390 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1391 args.reductionArgs);
1392 printBlockArgClause(p, ctx, "task_reduction",
1393 iface.getTaskReductionBlockArgs(),
1394 args.taskReductionArgs);
1395 printBlockArgClause(p, ctx, "use_device_addr",
1396 iface.getUseDeviceAddrBlockArgs(),
1397 args.useDeviceAddrArgs);
1398 printBlockArgClause(p, ctx, "use_device_ptr",
1399 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1400
1401 p.printRegion(region, /*printEntryBlockArgs=*/false);
1402}
1403
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1407 OpAsmPrinter &p, Operation *op, Region &region,
1408 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1409 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1410 ValueRange inReductionVars, TypeRange inReductionTypes,
1411 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1412 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1413 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1414 DenseI64ArrayAttr privateMaps) {
1415 AllRegionPrintArgs args;
1416 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1417 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1418 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1419 inReductionByref, inReductionSyms);
1420 args.mapArgs.emplace(mapVars, mapTypes);
1421 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1422 privateNeedsBarrier, privateMaps);
1423 printBlockArgRegion(p, op, region, args);
1424}
1425
1427 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1428 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1429 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1430 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1431 AllRegionPrintArgs args;
1432 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1433 inReductionByref, inReductionSyms);
1434 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1435 privateNeedsBarrier,
1436 /*mapIndices=*/nullptr);
1437 printBlockArgRegion(p, op, region, args);
1438}
1439
1441 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1442 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1443 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1444 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1445 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1446 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1447 ArrayAttr reductionSyms) {
1448 AllRegionPrintArgs args;
1449 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1450 inReductionByref, inReductionSyms);
1451 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1452 privateNeedsBarrier,
1453 /*mapIndices=*/nullptr);
1454 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1455 reductionSyms, reductionMod);
1456 printBlockArgRegion(p, op, region, args);
1457}
1458
1460 ValueRange privateVars, TypeRange privateTypes,
1461 ArrayAttr privateSyms,
1462 UnitAttr privateNeedsBarrier) {
1463 AllRegionPrintArgs args;
1464 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1465 privateNeedsBarrier,
1466 /*mapIndices=*/nullptr);
1467 printBlockArgRegion(p, op, region, args);
1468}
1469
1471 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1472 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1473 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1474 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1475 ArrayAttr reductionSyms) {
1476 AllRegionPrintArgs args;
1477 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1478 privateNeedsBarrier,
1479 /*mapIndices=*/nullptr);
1480 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1481 reductionSyms, reductionMod);
1482 printBlockArgRegion(p, op, region, args);
1483}
1484
1486 Region &region,
1487 ValueRange taskReductionVars,
1488 TypeRange taskReductionTypes,
1489 DenseBoolArrayAttr taskReductionByref,
1490 ArrayAttr taskReductionSyms) {
1491 AllRegionPrintArgs args;
1492 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1493 taskReductionByref, taskReductionSyms);
1494 printBlockArgRegion(p, op, region, args);
1495}
1496
1498 Region &region,
1499 ValueRange useDeviceAddrVars,
1500 TypeRange useDeviceAddrTypes,
1501 ValueRange useDevicePtrVars,
1502 TypeRange useDevicePtrTypes) {
1503 AllRegionPrintArgs args;
1504 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1505 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1506 printBlockArgRegion(p, op, region, args);
1507}
1508
1509template <typename ParsePrefixFn>
1510static ParseResult parseSplitIteratedList(
1511 OpAsmParser &parser,
1513 SmallVectorImpl<Type> &iteratedTypes,
1515 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1516
1517 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1518 if (failed(parsePrefix()))
1519 return failure();
1520
1522 Type ty;
1523 if (parser.parseOperand(v) || parser.parseColonType(ty))
1524 return failure();
1525
1526 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1527 iteratedVars.push_back(v);
1528 iteratedTypes.push_back(ty);
1529 } else {
1530 plainVars.push_back(v);
1531 plainTypes.push_back(ty);
1532 }
1533 return success();
1534 });
1535}
1536
1537template <typename PrintPrefixFn>
1539 TypeRange iteratedTypes,
1540 ValueRange plainVars, TypeRange plainTypes,
1541 PrintPrefixFn &&printPrefixForPlain,
1542 PrintPrefixFn &&printPrefixForIterated) {
1543
1544 bool first = true;
1545 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1546 if (!first)
1547 p << ", ";
1548 printPrefix(v, t);
1549 p << v << " : " << t;
1550 first = false;
1551 };
1552
1553 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1554 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1555 for (unsigned i = 0; i < plainVars.size(); ++i)
1556 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1557}
1558
1559/// Verifies Reduction Clause
1560static LogicalResult
1561verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1562 OperandRange reductionVars,
1563 std::optional<ArrayRef<bool>> reductionByref) {
1564 if (!reductionVars.empty()) {
1565 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1566 return op->emitOpError()
1567 << "expected as many reduction symbol references "
1568 "as reduction variables";
1569 if (reductionByref && reductionByref->size() != reductionVars.size())
1570 return op->emitError() << "expected as many reduction variable by "
1571 "reference attributes as reduction variables";
1572 } else {
1573 if (reductionSyms)
1574 return op->emitOpError() << "unexpected reduction symbol references";
1575 return success();
1576 }
1577
1578 // TODO: The followings should be done in
1579 // SymbolUserOpInterface::verifySymbolUses.
1580 DenseSet<Value> accumulators;
1581 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1582 Value accum = std::get<0>(args);
1583
1584 if (!accumulators.insert(accum).second)
1585 return op->emitOpError() << "accumulator variable used more than once";
1586
1587 Type varType = accum.getType();
1588 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1589 auto decl =
1591 if (!decl)
1592 return op->emitOpError() << "expected symbol reference " << symbolRef
1593 << " to point to a reduction declaration";
1594
1595 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1596 return op->emitOpError()
1597 << "expected accumulator (" << varType
1598 << ") to be the same type as reduction declaration ("
1599 << decl.getAccumulatorType() << ")";
1600 }
1601
1602 return success();
1603}
1604
1605//===----------------------------------------------------------------------===//
1606// Parser, printer and verifier for Copyprivate
1607//===----------------------------------------------------------------------===//
1608
1609/// copyprivate-entry-list ::= copyprivate-entry
1610/// | copyprivate-entry-list `,` copyprivate-entry
1611/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1612static ParseResult parseCopyprivate(
1613 OpAsmParser &parser,
1615 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1617 if (failed(parser.parseCommaSeparatedList([&]() {
1618 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1619 parser.parseArrow() ||
1620 parser.parseAttribute(symsVec.emplace_back()) ||
1621 parser.parseColonType(copyprivateTypes.emplace_back()))
1622 return failure();
1623 return success();
1624 })))
1625 return failure();
1626 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1627 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1628 return success();
1629}
1630
1631/// Print Copyprivate clause
1633 OperandRange copyprivateVars,
1634 TypeRange copyprivateTypes,
1635 std::optional<ArrayAttr> copyprivateSyms) {
1636 if (!copyprivateSyms.has_value())
1637 return;
1638 llvm::interleaveComma(
1639 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1640 [&](const auto &args) {
1641 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1642 << std::get<2>(args);
1643 });
1644}
1645
1646/// Verifies CopyPrivate Clause
1647static LogicalResult
1649 std::optional<ArrayAttr> copyprivateSyms) {
1650 size_t copyprivateSymsSize =
1651 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1652 if (copyprivateSymsSize != copyprivateVars.size())
1653 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1654 << copyprivateVars.size()
1655 << ") and functions (= " << copyprivateSymsSize
1656 << "), both must be equal";
1657 if (!copyprivateSyms.has_value())
1658 return success();
1659
1660 for (auto copyprivateVarAndSym :
1661 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1662 auto symbolRef =
1663 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1664 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1665 funcOp;
1666 if (mlir::func::FuncOp mlirFuncOp =
1668 symbolRef))
1669 funcOp = mlirFuncOp;
1670 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1672 op, symbolRef))
1673 funcOp = llvmFuncOp;
1674
1675 auto getNumArguments = [&] {
1676 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1677 };
1678
1679 auto getArgumentType = [&](unsigned i) {
1680 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1681 *funcOp);
1682 };
1683
1684 if (!funcOp)
1685 return op->emitOpError() << "expected symbol reference " << symbolRef
1686 << " to point to a copy function";
1687
1688 if (getNumArguments() != 2)
1689 return op->emitOpError()
1690 << "expected copy function " << symbolRef << " to have 2 operands";
1691
1692 Type argTy = getArgumentType(0);
1693 if (argTy != getArgumentType(1))
1694 return op->emitOpError() << "expected copy function " << symbolRef
1695 << " arguments to have the same type";
1696
1697 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1698 if (argTy != varType)
1699 return op->emitOpError()
1700 << "expected copy function arguments' type (" << argTy
1701 << ") to be the same as copyprivate variable's type (" << varType
1702 << ")";
1703 }
1704
1705 return success();
1706}
1707
1708//===----------------------------------------------------------------------===//
1709// Parser, printer and verifier for DependVarList
1710//===----------------------------------------------------------------------===//
1711
1712/// depend-entry-list ::= depend-entry
1713/// | depend-entry-list `,` depend-entry
1714/// depend-entry ::= depend-kind `->` ssa-id `:` type
1715/// | depend-kind `->` ssa-id `:` iterated-type
1716static ParseResult parseDependVarList(
1717 OpAsmParser &parser,
1719 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
1721 SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
1724 if (failed(parser.parseCommaSeparatedList([&]() {
1725 StringRef keyword;
1726 OpAsmParser::UnresolvedOperand operand;
1727 Type ty;
1728 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1729 parser.parseOperand(operand) || parser.parseColonType(ty))
1730 return failure();
1731 std::optional<ClauseTaskDepend> keywordDepend =
1732 symbolizeClauseTaskDepend(keyword);
1733 if (!keywordDepend)
1734 return failure();
1735 auto kindAttr =
1736 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1737 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1738 iteratedVars.push_back(operand);
1739 iteratedTypes.push_back(ty);
1740 iterKindsVec.push_back(kindAttr);
1741 } else {
1742 dependVars.push_back(operand);
1743 dependTypes.push_back(ty);
1744 kindsVec.push_back(kindAttr);
1745 }
1746 return success();
1747 })))
1748 return failure();
1749 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1750 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1751 SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
1752 iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
1753 return success();
1754}
1755
1756/// Print Depend clause
1758 OperandRange dependVars, TypeRange dependTypes,
1759 std::optional<ArrayAttr> dependKinds,
1760 OperandRange iteratedVars,
1761 TypeRange iteratedTypes,
1762 std::optional<ArrayAttr> iteratedKinds) {
1763 bool first = true;
1764 auto printEntries = [&](OperandRange vars, TypeRange types,
1765 std::optional<ArrayAttr> kinds) {
1766 for (unsigned i = 0, e = vars.size(); i < e; ++i) {
1767 if (!first)
1768 p << ", ";
1769 p << stringifyClauseTaskDepend(
1770 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1771 .getValue())
1772 << " -> " << vars[i] << " : " << types[i];
1773 first = false;
1774 }
1775 };
1776 printEntries(dependVars, dependTypes, dependKinds);
1777 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1778}
1779
1780/// Verifies Depend clause
1781static LogicalResult verifyDependVarList(Operation *op,
1782 std::optional<ArrayAttr> dependKinds,
1783 OperandRange dependVars,
1784 std::optional<ArrayAttr> iteratedKinds,
1785 OperandRange iteratedVars) {
1786 if (!dependVars.empty()) {
1787 if (!dependKinds || dependKinds->size() != dependVars.size())
1788 return op->emitOpError() << "expected as many depend values"
1789 " as depend variables";
1790 } else {
1791 if (dependKinds && !dependKinds->empty())
1792 return op->emitOpError() << "unexpected depend values";
1793 }
1794
1795 if (!iteratedVars.empty()) {
1796 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1797 return op->emitOpError() << "expected as many depend iterated values"
1798 " as depend iterated variables";
1799 } else {
1800 if (iteratedKinds && !iteratedKinds->empty())
1801 return op->emitOpError() << "unexpected depend iterated values";
1802 }
1803
1804 return success();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// Parser, printer and verifier for Synchronization Hint (2.17.12)
1809//===----------------------------------------------------------------------===//
1810
1811/// Parses a Synchronization Hint clause. The value of hint is an integer
1812/// which is a combination of different hints from `omp_sync_hint_t`.
1813///
1814/// hint-clause = `hint` `(` hint-value `)`
1815static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1816 IntegerAttr &hintAttr) {
1817 StringRef hintKeyword;
1818 int64_t hint = 0;
1819 if (succeeded(parser.parseOptionalKeyword("none"))) {
1820 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1821 return success();
1822 }
1823 auto parseKeyword = [&]() -> ParseResult {
1824 if (failed(parser.parseKeyword(&hintKeyword)))
1825 return failure();
1826 if (hintKeyword == "uncontended")
1827 hint |= 1;
1828 else if (hintKeyword == "contended")
1829 hint |= 2;
1830 else if (hintKeyword == "nonspeculative")
1831 hint |= 4;
1832 else if (hintKeyword == "speculative")
1833 hint |= 8;
1834 else
1835 return parser.emitError(parser.getCurrentLocation())
1836 << hintKeyword << " is not a valid hint";
1837 return success();
1838 };
1839 if (parser.parseCommaSeparatedList(parseKeyword))
1840 return failure();
1841 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1842 return success();
1843}
1844
1845/// Prints a Synchronization Hint clause
1847 IntegerAttr hintAttr) {
1848 int64_t hint = hintAttr.getInt();
1849
1850 if (hint == 0) {
1851 p << "none";
1852 return;
1853 }
1854
1855 // Helper function to get n-th bit from the right end of `value`
1856 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1857
1858 bool uncontended = bitn(hint, 0);
1859 bool contended = bitn(hint, 1);
1860 bool nonspeculative = bitn(hint, 2);
1861 bool speculative = bitn(hint, 3);
1862
1864 if (uncontended)
1865 hints.push_back("uncontended");
1866 if (contended)
1867 hints.push_back("contended");
1868 if (nonspeculative)
1869 hints.push_back("nonspeculative");
1870 if (speculative)
1871 hints.push_back("speculative");
1872
1873 llvm::interleaveComma(hints, p);
1874}
1875
1876/// Verifies a synchronization hint clause
1877static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1878
1879 // Helper function to get n-th bit from the right end of `value`
1880 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1881
1882 bool uncontended = bitn(hint, 0);
1883 bool contended = bitn(hint, 1);
1884 bool nonspeculative = bitn(hint, 2);
1885 bool speculative = bitn(hint, 3);
1886
1887 if (uncontended && contended)
1888 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1889 "omp_sync_hint_contended cannot be combined";
1890 if (nonspeculative && speculative)
1891 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1892 "omp_sync_hint_speculative cannot be combined.";
1893 return success();
1894}
1895
1896//===----------------------------------------------------------------------===//
1897// Parser, printer and verifier for Target
1898//===----------------------------------------------------------------------===//
1899
1900// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1901// boolean
1902static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1903 return (value & flag) == flag;
1904}
1905
1906/// Parses a map_entries map type from a string format back into its numeric
1907/// value.
1908///
1909/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1910/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1911static ParseResult parseMapClause(OpAsmParser &parser,
1912 ClauseMapFlagsAttr &mapType) {
1913 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1914 // This simply verifies the correct keyword is read in, the
1915 // keyword itself is stored inside of the operation
1916 auto parseTypeAndMod = [&]() -> ParseResult {
1917 StringRef mapTypeMod;
1918 if (parser.parseKeyword(&mapTypeMod))
1919 return failure();
1920
1921 if (mapTypeMod == "always")
1922 mapTypeBits |= ClauseMapFlags::always;
1923
1924 if (mapTypeMod == "implicit")
1925 mapTypeBits |= ClauseMapFlags::implicit;
1926
1927 if (mapTypeMod == "ompx_hold")
1928 mapTypeBits |= ClauseMapFlags::ompx_hold;
1929
1930 if (mapTypeMod == "close")
1931 mapTypeBits |= ClauseMapFlags::close;
1932
1933 if (mapTypeMod == "present")
1934 mapTypeBits |= ClauseMapFlags::present;
1935
1936 if (mapTypeMod == "to")
1937 mapTypeBits |= ClauseMapFlags::to;
1938
1939 if (mapTypeMod == "from")
1940 mapTypeBits |= ClauseMapFlags::from;
1941
1942 if (mapTypeMod == "tofrom")
1943 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1944
1945 if (mapTypeMod == "delete")
1946 mapTypeBits |= ClauseMapFlags::del;
1947
1948 if (mapTypeMod == "storage")
1949 mapTypeBits |= ClauseMapFlags::storage;
1950
1951 if (mapTypeMod == "return_param")
1952 mapTypeBits |= ClauseMapFlags::return_param;
1953
1954 if (mapTypeMod == "private")
1955 mapTypeBits |= ClauseMapFlags::priv;
1956
1957 if (mapTypeMod == "literal")
1958 mapTypeBits |= ClauseMapFlags::literal;
1959
1960 if (mapTypeMod == "attach")
1961 mapTypeBits |= ClauseMapFlags::attach;
1962
1963 if (mapTypeMod == "attach_always")
1964 mapTypeBits |= ClauseMapFlags::attach_always;
1965
1966 if (mapTypeMod == "attach_never")
1967 mapTypeBits |= ClauseMapFlags::attach_never;
1968
1969 if (mapTypeMod == "attach_auto")
1970 mapTypeBits |= ClauseMapFlags::attach_auto;
1971
1972 if (mapTypeMod == "ref_ptr")
1973 mapTypeBits |= ClauseMapFlags::ref_ptr;
1974
1975 if (mapTypeMod == "ref_ptee")
1976 mapTypeBits |= ClauseMapFlags::ref_ptee;
1977
1978 if (mapTypeMod == "ref_ptr_ptee")
1979 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1980
1981 if (mapTypeMod == "is_device_ptr")
1982 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1983
1984 return success();
1985 };
1986
1987 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1988 return failure();
1989
1990 mapType =
1991 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1992
1993 return success();
1994}
1995
1996/// Prints a map_entries map type from its numeric value out into its string
1997/// format.
1998static void printMapClause(OpAsmPrinter &p, Operation *op,
1999 ClauseMapFlagsAttr mapType) {
2001 ClauseMapFlags mapFlags = mapType.getValue();
2002
2003 // handling of always, close, present placed at the beginning of the string
2004 // to aid readability
2005 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2006 mapTypeStrs.push_back("always");
2007 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2008 mapTypeStrs.push_back("implicit");
2009 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2010 mapTypeStrs.push_back("ompx_hold");
2011 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2012 mapTypeStrs.push_back("close");
2013 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2014 mapTypeStrs.push_back("present");
2015
2016 // special handling of to/from/tofrom/delete and release/alloc, release +
2017 // alloc are the abscense of one of the other flags, whereas tofrom requires
2018 // both the to and from flag to be set.
2019 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2020 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2021
2022 if (to && from)
2023 mapTypeStrs.push_back("tofrom");
2024 else if (from)
2025 mapTypeStrs.push_back("from");
2026 else if (to)
2027 mapTypeStrs.push_back("to");
2028
2029 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2030 mapTypeStrs.push_back("delete");
2031 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2032 mapTypeStrs.push_back("return_param");
2033 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2034 mapTypeStrs.push_back("storage");
2035 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2036 mapTypeStrs.push_back("private");
2037 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2038 mapTypeStrs.push_back("literal");
2039 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2040 mapTypeStrs.push_back("attach");
2041 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2042 mapTypeStrs.push_back("attach_always");
2043 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2044 mapTypeStrs.push_back("attach_never");
2045 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2046 mapTypeStrs.push_back("attach_auto");
2047 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2048 mapTypeStrs.push_back("ref_ptr");
2049 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2050 mapTypeStrs.push_back("ref_ptee");
2051 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
2052 mapTypeStrs.push_back("ref_ptr_ptee");
2053 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2054 mapTypeStrs.push_back("is_device_ptr");
2055 if (mapFlags == ClauseMapFlags::none)
2056 mapTypeStrs.push_back("none");
2057
2058 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2059 p << mapTypeStrs[i];
2060 if (i + 1 < mapTypeStrs.size()) {
2061 p << ", ";
2062 }
2063 }
2064}
2065
2066static ParseResult parseMembersIndex(OpAsmParser &parser,
2067 ArrayAttr &membersIdx) {
2068 SmallVector<Attribute> values, memberIdxs;
2069
2070 auto parseIndices = [&]() -> ParseResult {
2071 int64_t value;
2072 if (parser.parseInteger(value))
2073 return failure();
2074 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2075 APInt(64, value, /*isSigned=*/false)));
2076 return success();
2077 };
2078
2079 do {
2080 if (failed(parser.parseLSquare()))
2081 return failure();
2082
2083 if (parser.parseCommaSeparatedList(parseIndices))
2084 return failure();
2085
2086 if (failed(parser.parseRSquare()))
2087 return failure();
2088
2089 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2090 values.clear();
2091 } while (succeeded(parser.parseOptionalComma()));
2092
2093 if (!memberIdxs.empty())
2094 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2095
2096 return success();
2097}
2098
2099static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2100 ArrayAttr membersIdx) {
2101 if (!membersIdx)
2102 return;
2103
2104 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2105 p << "[";
2106 auto memberIdx = cast<ArrayAttr>(v);
2107 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2108 p << cast<IntegerAttr>(v2).getInt();
2109 });
2110 p << "]";
2111 });
2112}
2113
2115 VariableCaptureKindAttr mapCaptureType) {
2116 std::string typeCapStr;
2117 llvm::raw_string_ostream typeCap(typeCapStr);
2118 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2119 typeCap << "ByRef";
2120 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2121 typeCap << "ByCopy";
2122 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2123 typeCap << "VLAType";
2124 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2125 typeCap << "This";
2126 p << typeCapStr;
2127}
2128
2129static ParseResult parseCaptureType(OpAsmParser &parser,
2130 VariableCaptureKindAttr &mapCaptureType) {
2131 StringRef mapCaptureKey;
2132 if (parser.parseKeyword(&mapCaptureKey))
2133 return failure();
2134
2135 if (mapCaptureKey == "This")
2136 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2137 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2138 if (mapCaptureKey == "ByRef")
2139 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2140 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2141 if (mapCaptureKey == "ByCopy")
2142 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2143 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2144 if (mapCaptureKey == "VLAType")
2145 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2146 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2147
2148 return success();
2149}
2150
2151static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
2154
2155 for (auto mapOp : mapVars) {
2156 if (!mapOp.getDefiningOp())
2157 return emitError(op->getLoc(), "missing map operation");
2158
2159 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2160 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2161
2162 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2163 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2164 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2165
2166 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2167 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2168 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2169
2170 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2171 return emitError(op->getLoc(),
2172 "to, from, tofrom and alloc map types are permitted");
2173
2174 if (isa<TargetEnterDataOp>(op) && (from || del))
2175 return emitError(op->getLoc(), "to and alloc map types are permitted");
2176
2177 if (isa<TargetExitDataOp>(op) && to)
2178 return emitError(op->getLoc(),
2179 "from, release and delete map types are permitted");
2180
2181 if (isa<TargetUpdateOp>(op)) {
2182 if (del) {
2183 return emitError(op->getLoc(),
2184 "at least one of to or from map types must be "
2185 "specified, other map types are not permitted");
2186 }
2187
2188 if (!to && !from) {
2189 return emitError(op->getLoc(),
2190 "at least one of to or from map types must be "
2191 "specified, other map types are not permitted");
2192 }
2193
2194 auto updateVar = mapInfoOp.getVarPtr();
2195
2196 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2197 (from && updateToVars.contains(updateVar))) {
2198 return emitError(
2199 op->getLoc(),
2200 "either to or from map types can be specified, not both");
2201 }
2202
2203 if (always || close || implicit) {
2204 return emitError(
2205 op->getLoc(),
2206 "present, mapper and iterator map type modifiers are permitted");
2207 }
2208
2209 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2210 }
2211 } else if (!isa<DeclareMapperInfoOp>(op)) {
2212 return emitError(op->getLoc(),
2213 "map argument is not a map entry operation");
2214 }
2215 }
2216
2217 return success();
2218}
2219
2220static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2221 std::optional<DenseI64ArrayAttr> privateMapIndices =
2222 targetOp.getPrivateMapsAttr();
2223
2224 // None of the private operands are mapped.
2225 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2226 return success();
2227
2228 OperandRange privateVars = targetOp.getPrivateVars();
2229
2230 if (privateMapIndices.value().size() !=
2231 static_cast<int64_t>(privateVars.size()))
2232 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2233 "`private_maps` attribute mismatch");
2234
2235 return success();
2236}
2237
2238//===----------------------------------------------------------------------===//
2239// MapInfoOp
2240//===----------------------------------------------------------------------===//
2241
2242static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2243 StringRef clauseName,
2244 OperandRange vars) {
2245 for (Value var : vars)
2246 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2247 return op->emitOpError()
2248 << "'" << clauseName
2249 << "' arguments must be defined by 'omp.map.info' ops";
2250 return success();
2251}
2252
2253LogicalResult MapInfoOp::verify() {
2254 if (getMapperId() &&
2256 *this, getMapperIdAttr())) {
2257 return emitError("invalid mapper id");
2258 }
2259
2260 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2261 return failure();
2262
2263 return success();
2264}
2265
2266//===----------------------------------------------------------------------===//
2267// TargetDataOp
2268//===----------------------------------------------------------------------===//
2269
2270void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2271 const TargetDataOperands &clauses) {
2272 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2273 clauses.mapVars, clauses.useDeviceAddrVars,
2274 clauses.useDevicePtrVars);
2275}
2276
2277LogicalResult TargetDataOp::verify() {
2278 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2279 getUseDeviceAddrVars().empty()) {
2280 return ::emitError(this->getLoc(),
2281 "At least one of map, use_device_ptr_vars, or "
2282 "use_device_addr_vars operand must be present");
2283 }
2284
2285 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2286 getUseDevicePtrVars())))
2287 return failure();
2288
2289 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2290 getUseDeviceAddrVars())))
2291 return failure();
2292
2293 return verifyMapClause(*this, getMapVars());
2294}
2295
2296//===----------------------------------------------------------------------===//
2297// TargetEnterDataOp
2298//===----------------------------------------------------------------------===//
2299
2300void TargetEnterDataOp::build(
2301 OpBuilder &builder, OperationState &state,
2302 const TargetEnterExitUpdateDataOperands &clauses) {
2303 MLIRContext *ctx = builder.getContext();
2304 TargetEnterDataOp::build(
2305 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2306 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2307 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2308 clauses.nowait);
2309}
2310
2311LogicalResult TargetEnterDataOp::verify() {
2312 LogicalResult verifyDependVars =
2313 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2314 getDependIteratedKinds(), getDependIterated());
2315 return failed(verifyDependVars) ? verifyDependVars
2316 : verifyMapClause(*this, getMapVars());
2317}
2318
2319//===----------------------------------------------------------------------===//
2320// TargetExitDataOp
2321//===----------------------------------------------------------------------===//
2322
2323void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2324 const TargetEnterExitUpdateDataOperands &clauses) {
2325 MLIRContext *ctx = builder.getContext();
2326 TargetExitDataOp::build(
2327 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2328 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2329 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2330 clauses.nowait);
2331}
2332
2333LogicalResult TargetExitDataOp::verify() {
2334 LogicalResult verifyDependVars =
2335 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2336 getDependIteratedKinds(), getDependIterated());
2337 return failed(verifyDependVars) ? verifyDependVars
2338 : verifyMapClause(*this, getMapVars());
2339}
2340
2341//===----------------------------------------------------------------------===//
2342// TargetUpdateOp
2343//===----------------------------------------------------------------------===//
2344
2345void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2346 const TargetEnterExitUpdateDataOperands &clauses) {
2347 MLIRContext *ctx = builder.getContext();
2348 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2349 clauses.dependVars,
2350 makeArrayAttr(ctx, clauses.dependIteratedKinds),
2351 clauses.dependIterated, clauses.device, clauses.ifExpr,
2352 clauses.mapVars, clauses.nowait);
2353}
2354
2355LogicalResult TargetUpdateOp::verify() {
2356 LogicalResult verifyDependVars =
2357 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2358 getDependIteratedKinds(), getDependIterated());
2359 return failed(verifyDependVars) ? verifyDependVars
2360 : verifyMapClause(*this, getMapVars());
2361}
2362
2363//===----------------------------------------------------------------------===//
2364// TargetOp
2365//===----------------------------------------------------------------------===//
2366
2367void TargetOp::build(OpBuilder &builder, OperationState &state,
2368 const TargetOperands &clauses) {
2369 MLIRContext *ctx = builder.getContext();
2370 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2371 // inReductionByref, inReductionSyms.
2372 TargetOp::build(
2373 builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
2374 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2375 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2376 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2377 clauses.ifExpr,
2378 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2379 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2380 clauses.nowait, clauses.privateVars,
2381 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2382 clauses.threadLimitVars,
2383 /*private_maps=*/nullptr);
2384}
2385
2386LogicalResult TargetOp::verify() {
2387 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2388 getDependIteratedKinds(),
2389 getDependIterated())))
2390 return failure();
2391
2392 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2393 getHasDeviceAddrVars())))
2394 return failure();
2395
2396 if (failed(verifyMapClause(*this, getMapVars())))
2397 return failure();
2398
2399 return verifyPrivateVarsMapping(*this);
2400}
2401
2402LogicalResult TargetOp::verifyRegions() {
2403 auto teamsOps = getOps<TeamsOp>();
2404 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2405 return emitError("target containing multiple 'omp.teams' nested ops");
2406
2407 // Check that host_eval values are only used in legal ways.
2408 Operation *capturedOp = getInnermostCapturedOmpOp();
2409 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2410 for (Value hostEvalArg :
2411 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2412 for (Operation *user : hostEvalArg.getUsers()) {
2413 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2414 // Check if used in num_teams_lower or any of num_teams_upper_vars
2415 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2416 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2417 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2418 continue;
2419
2420 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2421 "and 'thread_limit' in 'omp.teams'";
2422 }
2423 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2424 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2425 parallelOp->isAncestor(capturedOp) &&
2426 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2427 continue;
2428
2429 return emitOpError()
2430 << "host_eval argument only legal as 'num_threads' in "
2431 "'omp.parallel' when representing target SPMD";
2432 }
2433 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2434 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2435 loopNestOp.getOperation() == capturedOp &&
2436 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2437 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2439 continue;
2440
2441 return emitOpError() << "host_eval argument only legal as loop bounds "
2442 "and steps in 'omp.loop_nest' when trip count "
2443 "must be evaluated in the host";
2444 }
2445
2446 return emitOpError() << "host_eval argument illegal use in '"
2447 << user->getName() << "' operation";
2448 }
2449 }
2450 return success();
2451}
2452
2453static Operation *
2454findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2455 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2456 assert(rootOp && "expected valid operation");
2457
2458 Dialect *ompDialect = rootOp->getDialect();
2459 Operation *capturedOp = nullptr;
2460 DominanceInfo domInfo;
2461
2462 // Process in pre-order to check operations from outermost to innermost,
2463 // ensuring we only enter the region of an operation if it meets the criteria
2464 // for being captured. We stop the exploration of nested operations as soon as
2465 // we process a region holding no operations to be captured.
2466 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2467 if (op == rootOp)
2468 return WalkResult::advance();
2469
2470 // Ignore operations of other dialects or omp operations with no regions,
2471 // because these will only be checked if they are siblings of an omp
2472 // operation that can potentially be captured.
2473 bool isOmpDialect = op->getDialect() == ompDialect;
2474 bool hasRegions = op->getNumRegions() > 0;
2475 if (!isOmpDialect || !hasRegions)
2476 return WalkResult::skip();
2477
2478 // This operation cannot be captured if it can be executed more than once
2479 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2480 // be executed before all exits of the region (i.e. it doesn't dominate all
2481 // blocks with no successors reachable from the entry block).
2482 if (checkSingleMandatoryExec) {
2483 Region *parentRegion = op->getParentRegion();
2484 Block *parentBlock = op->getBlock();
2485
2486 for (Block *successor : parentBlock->getSuccessors())
2487 if (successor->isReachable(parentBlock))
2488 return WalkResult::interrupt();
2489
2490 for (Block &block : *parentRegion)
2491 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2492 !domInfo.dominates(parentBlock, &block))
2493 return WalkResult::interrupt();
2494 }
2495
2496 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2497 // into nested operations.
2498 for (Operation &sibling : op->getParentRegion()->getOps())
2499 if (&sibling != op && !siblingAllowedFn(&sibling))
2500 return WalkResult::interrupt();
2501
2502 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2503 // Otherwise, process the contents of this operation.
2504 capturedOp = op;
2505 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2507 });
2508
2509 return capturedOp;
2510}
2511
2512Operation *TargetOp::getInnermostCapturedOmpOp() {
2513 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2514
2515 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2516 // effects, but don't include a memory write effect.
2517 return findCapturedOmpOp(
2518 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2519 if (!sibling)
2520 return false;
2521
2522 if (ompDialect == sibling->getDialect())
2523 return sibling->hasTrait<OpTrait::IsTerminator>();
2524
2525 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2527 effects;
2528 memOp.getEffects(effects);
2529 return !llvm::any_of(
2530 effects, [&](MemoryEffects::EffectInstance &effect) {
2531 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2532 isa<SideEffects::AutomaticAllocationScopeResource>(
2533 effect.getResource());
2534 });
2535 }
2536 return true;
2537 });
2538}
2539
2540/// Check if we can promote SPMD kernel to No-Loop kernel.
2541static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2542 WsloopOp *wsLoopOp) {
2543 // num_teams clause can break no-loop teams/threads assumption.
2544 if (!teamsOp.getNumTeamsUpperVars().empty())
2545 return false;
2546
2547 // Reduction kernels are slower in no-loop mode.
2548 if (teamsOp.getNumReductionVars())
2549 return false;
2550 if (wsLoopOp->getNumReductionVars())
2551 return false;
2552
2553 // Check if the user allows the promotion of kernels to no-loop mode.
2554 OffloadModuleInterface offloadMod =
2555 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2556 if (!offloadMod)
2557 return false;
2558 auto ompFlags = offloadMod.getFlags();
2559 if (!ompFlags)
2560 return false;
2561 return ompFlags.getAssumeTeamsOversubscription() &&
2562 ompFlags.getAssumeThreadsOversubscription();
2563}
2564
2565TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2566 // A non-null captured op is only valid if it resides inside of a TargetOp
2567 // and is the result of calling getInnermostCapturedOmpOp() on it.
2568 TargetOp targetOp =
2569 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2570 assert((!capturedOp ||
2571 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2572 "unexpected captured op");
2573
2574 // If it's not capturing a loop, it's a default target region.
2575 if (!isa_and_present<LoopNestOp>(capturedOp))
2576 return TargetRegionFlags::generic;
2577
2578 // Get the innermost non-simd loop wrapper.
2580 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2581 assert(!loopWrappers.empty());
2582
2583 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2584 if (isa<SimdOp>(innermostWrapper))
2585 innermostWrapper = std::next(innermostWrapper);
2586
2587 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2588 if (numWrappers != 1 && numWrappers != 2)
2589 return TargetRegionFlags::generic;
2590
2591 // Detect target-teams-distribute-parallel-wsloop[-simd].
2592 if (numWrappers == 2) {
2593 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2594 if (!wsloopOp)
2595 return TargetRegionFlags::generic;
2596
2597 innermostWrapper = std::next(innermostWrapper);
2598 if (!isa<DistributeOp>(innermostWrapper))
2599 return TargetRegionFlags::generic;
2600
2601 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2602 if (!isa_and_present<ParallelOp>(parallelOp))
2603 return TargetRegionFlags::generic;
2604
2605 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2606 if (!teamsOp)
2607 return TargetRegionFlags::generic;
2608
2609 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2610 TargetRegionFlags result =
2611 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2612 if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2613 result = result | TargetRegionFlags::no_loop;
2614 return result;
2615 }
2616 }
2617 // Detect target-teams-distribute[-simd] and target-teams-loop.
2618 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2619 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2620 if (!isa_and_present<TeamsOp>(teamsOp))
2621 return TargetRegionFlags::generic;
2622
2623 if (teamsOp->getParentOp() != targetOp.getOperation())
2624 return TargetRegionFlags::generic;
2625
2626 if (isa<LoopOp>(innermostWrapper))
2627 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2628
2629 // Find single immediately nested captured omp.parallel and add spmd flag
2630 // (generic-spmd case).
2631 //
2632 // TODO: This shouldn't have to be done here, as it is too easy to break.
2633 // The openmp-opt pass should be updated to be able to promote kernels like
2634 // this from "Generic" to "Generic-SPMD". However, the use of the
2635 // `kmpc_distribute_static_loop` family of functions produced by the
2636 // OMPIRBuilder for these kernels prevents that from working.
2637 Dialect *ompDialect = targetOp->getDialect();
2638 Operation *nestedCapture = findCapturedOmpOp(
2639 capturedOp, /*checkSingleMandatoryExec=*/false,
2640 [&](Operation *sibling) {
2641 return sibling && (ompDialect != sibling->getDialect() ||
2642 sibling->hasTrait<OpTrait::IsTerminator>());
2643 });
2644
2645 TargetRegionFlags result =
2646 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2647
2648 if (!nestedCapture)
2649 return result;
2650
2651 while (nestedCapture->getParentOp() != capturedOp)
2652 nestedCapture = nestedCapture->getParentOp();
2653
2654 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2655 : result;
2656 }
2657 // Detect target-parallel-wsloop[-simd].
2658 else if (isa<WsloopOp>(innermostWrapper)) {
2659 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2660 if (!isa_and_present<ParallelOp>(parallelOp))
2661 return TargetRegionFlags::generic;
2662
2663 if (parallelOp->getParentOp() == targetOp.getOperation())
2664 return TargetRegionFlags::spmd;
2665 }
2666
2667 return TargetRegionFlags::generic;
2668}
2669
2670//===----------------------------------------------------------------------===//
2671// ParallelOp
2672//===----------------------------------------------------------------------===//
2673
2674void ParallelOp::build(OpBuilder &builder, OperationState &state,
2675 ArrayRef<NamedAttribute> attributes) {
2676 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2677 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2678 /*num_threads_vars=*/ValueRange(),
2679 /*private_vars=*/ValueRange(),
2680 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2681 /*proc_bind_kind=*/nullptr,
2682 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2683 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2684 state.addAttributes(attributes);
2685}
2686
2687void ParallelOp::build(OpBuilder &builder, OperationState &state,
2688 const ParallelOperands &clauses) {
2689 MLIRContext *ctx = builder.getContext();
2690 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2691 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2692 makeArrayAttr(ctx, clauses.privateSyms),
2693 clauses.privateNeedsBarrier, clauses.procBindKind,
2694 clauses.reductionMod, clauses.reductionVars,
2695 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2696 makeArrayAttr(ctx, clauses.reductionSyms));
2697}
2698
2699template <typename OpType>
2700static LogicalResult verifyPrivateVarList(OpType &op) {
2701 auto privateVars = op.getPrivateVars();
2702 auto privateSyms = op.getPrivateSymsAttr();
2703
2704 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2705 return success();
2706
2707 auto numPrivateVars = privateVars.size();
2708 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2709
2710 if (numPrivateVars != numPrivateSyms)
2711 return op.emitError() << "inconsistent number of private variables and "
2712 "privatizer op symbols, private vars: "
2713 << numPrivateVars
2714 << " vs. privatizer op symbols: " << numPrivateSyms;
2715
2716 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2717 Type varType = std::get<0>(privateVarInfo).getType();
2718 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2719 PrivateClauseOp privatizerOp =
2721
2722 if (privatizerOp == nullptr)
2723 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2724 << privateSym << "'";
2725
2726 Type privatizerType = privatizerOp.getArgType();
2727
2728 if (privatizerType && (varType != privatizerType))
2729 return op.emitError()
2730 << "type mismatch between a "
2731 << (privatizerOp.getDataSharingType() ==
2732 DataSharingClauseType::Private
2733 ? "private"
2734 : "firstprivate")
2735 << " variable and its privatizer op, var type: " << varType
2736 << " vs. privatizer op type: " << privatizerType;
2737 }
2738
2739 return success();
2740}
2741
2742LogicalResult ParallelOp::verify() {
2743 if (getAllocateVars().size() != getAllocatorVars().size())
2744 return emitError(
2745 "expected equal sizes for allocate and allocator variables");
2746
2747 if (failed(verifyPrivateVarList(*this)))
2748 return failure();
2749
2750 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2751 getReductionByref());
2752}
2753
2754LogicalResult ParallelOp::verifyRegions() {
2755 auto distChildOps = getOps<DistributeOp>();
2756 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2757 if (numDistChildOps > 1)
2758 return emitError()
2759 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2760
2761 if (numDistChildOps == 1) {
2762 if (!isComposite())
2763 return emitError()
2764 << "'omp.composite' attribute missing from composite operation";
2765
2766 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2767 Operation &distributeOp = **distChildOps.begin();
2768 for (Operation &childOp : getOps()) {
2769 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2770 continue;
2771
2772 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2773 return emitError() << "unexpected OpenMP operation inside of composite "
2774 "'omp.parallel': "
2775 << childOp.getName();
2776 }
2777 } else if (isComposite()) {
2778 return emitError()
2779 << "'omp.composite' attribute present in non-composite operation";
2780 }
2781 return success();
2782}
2783
2784//===----------------------------------------------------------------------===//
2785// TeamsOp
2786//===----------------------------------------------------------------------===//
2787
2789 while ((op = op->getParentOp()))
2790 if (isa<OpenMPDialect>(op->getDialect()))
2791 return false;
2792 return true;
2793}
2794
2795void TeamsOp::build(OpBuilder &builder, OperationState &state,
2796 const TeamsOperands &clauses) {
2797 MLIRContext *ctx = builder.getContext();
2798 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2799 TeamsOp::build(
2800 builder, state, clauses.allocateVars, clauses.allocatorVars,
2801 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2802 /*private_vars=*/{}, /*private_syms=*/nullptr,
2803 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2804 clauses.reductionVars,
2805 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2806 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2807}
2808
2809// Verify num_teams clause
2810static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2811 OperandRange numTeamsUpperVars) {
2812 // If lower is specified, upper must have exactly one value
2813 if (numTeamsLower) {
2814 if (numTeamsUpperVars.size() != 1)
2815 return op->emitError(
2816 "expected exactly one num_teams upper bound when lower bound is "
2817 "specified");
2818 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2819 return op->emitError(
2820 "expected num_teams upper bound and lower bound to be "
2821 "the same type");
2822 }
2823
2824 return success();
2825}
2826
2827LogicalResult TeamsOp::verify() {
2828 // Check parent region
2829 // TODO If nested inside of a target region, also check that it does not
2830 // contain any statements, declarations or directives other than this
2831 // omp.teams construct. The issue is how to support the initialization of
2832 // this operation's own arguments (allow SSA values across omp.target?).
2833 Operation *op = getOperation();
2834 if (!isa<TargetOp>(op->getParentOp()) &&
2836 return emitError("expected to be nested inside of omp.target or not nested "
2837 "in any OpenMP dialect operations");
2838
2839 // Check for num_teams clause restrictions
2840 if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
2841 this->getNumTeamsUpperVars())))
2842 return failure();
2843
2844 // Check for allocate clause restrictions
2845 if (getAllocateVars().size() != getAllocatorVars().size())
2846 return emitError(
2847 "expected equal sizes for allocate and allocator variables");
2848
2849 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2850 getReductionByref());
2851}
2852
2853//===----------------------------------------------------------------------===//
2854// SectionOp
2855//===----------------------------------------------------------------------===//
2856
2857OperandRange SectionOp::getPrivateVars() {
2858 return getParentOp().getPrivateVars();
2859}
2860
2861OperandRange SectionOp::getReductionVars() {
2862 return getParentOp().getReductionVars();
2863}
2864
2865//===----------------------------------------------------------------------===//
2866// SectionsOp
2867//===----------------------------------------------------------------------===//
2868
2869void SectionsOp::build(OpBuilder &builder, OperationState &state,
2870 const SectionsOperands &clauses) {
2871 MLIRContext *ctx = builder.getContext();
2872 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2873 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2874 clauses.nowait, /*private_vars=*/{},
2875 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2876 clauses.reductionMod, clauses.reductionVars,
2877 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2878 makeArrayAttr(ctx, clauses.reductionSyms));
2879}
2880
2881LogicalResult SectionsOp::verify() {
2882 if (getAllocateVars().size() != getAllocatorVars().size())
2883 return emitError(
2884 "expected equal sizes for allocate and allocator variables");
2885
2886 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2887 getReductionByref());
2888}
2889
2890LogicalResult SectionsOp::verifyRegions() {
2891 for (auto &inst : *getRegion().begin()) {
2892 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2893 return emitOpError()
2894 << "expected omp.section op or terminator op inside region";
2895 }
2896 }
2897
2898 return success();
2899}
2900
2901//===----------------------------------------------------------------------===//
2902// SingleOp
2903//===----------------------------------------------------------------------===//
2904
2905void SingleOp::build(OpBuilder &builder, OperationState &state,
2906 const SingleOperands &clauses) {
2907 MLIRContext *ctx = builder.getContext();
2908 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2909 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2910 clauses.copyprivateVars,
2911 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2912 /*private_vars=*/{}, /*private_syms=*/nullptr,
2913 /*private_needs_barrier=*/nullptr);
2914}
2915
2916LogicalResult SingleOp::verify() {
2917 // Check for allocate clause restrictions
2918 if (getAllocateVars().size() != getAllocatorVars().size())
2919 return emitError(
2920 "expected equal sizes for allocate and allocator variables");
2921
2922 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2923 getCopyprivateSyms());
2924}
2925
2926//===----------------------------------------------------------------------===//
2927// WorkshareOp
2928//===----------------------------------------------------------------------===//
2929
2930void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2931 const WorkshareOperands &clauses) {
2932 WorkshareOp::build(builder, state, clauses.nowait);
2933}
2934
2935//===----------------------------------------------------------------------===//
2936// WorkshareLoopWrapperOp
2937//===----------------------------------------------------------------------===//
2938
2939LogicalResult WorkshareLoopWrapperOp::verify() {
2940 if (!(*this)->getParentOfType<WorkshareOp>())
2941 return emitOpError() << "must be nested in an omp.workshare";
2942 return success();
2943}
2944
2945LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2946 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2947 getNestedWrapper())
2948 return emitOpError() << "expected to be a standalone loop wrapper";
2949
2950 return success();
2951}
2952
2953//===----------------------------------------------------------------------===//
2954// LoopWrapperInterface
2955//===----------------------------------------------------------------------===//
2956
2957LogicalResult LoopWrapperInterface::verifyImpl() {
2958 Operation *op = this->getOperation();
2959 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2961 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2962 "and `SingleBlock` traits";
2963
2964 if (op->getNumRegions() != 1)
2965 return emitOpError() << "loop wrapper does not contain exactly one region";
2966
2967 Region &region = op->getRegion(0);
2968 if (range_size(region.getOps()) != 1)
2969 return emitOpError()
2970 << "loop wrapper does not contain exactly one nested op";
2971
2972 Operation &firstOp = *region.op_begin();
2973 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2974 return emitOpError() << "nested in loop wrapper is not another loop "
2975 "wrapper or `omp.loop_nest`";
2976
2977 return success();
2978}
2979
2980//===----------------------------------------------------------------------===//
2981// LoopOp
2982//===----------------------------------------------------------------------===//
2983
2984void LoopOp::build(OpBuilder &builder, OperationState &state,
2985 const LoopOperands &clauses) {
2986 MLIRContext *ctx = builder.getContext();
2987
2988 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2989 makeArrayAttr(ctx, clauses.privateSyms),
2990 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2991 clauses.reductionMod, clauses.reductionVars,
2992 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2993 makeArrayAttr(ctx, clauses.reductionSyms));
2994}
2995
2996LogicalResult LoopOp::verify() {
2997 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2998 getReductionByref());
2999}
3000
3001LogicalResult LoopOp::verifyRegions() {
3002 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3003 getNestedWrapper())
3004 return emitOpError() << "expected to be a standalone loop wrapper";
3005
3006 return success();
3007}
3008
3009//===----------------------------------------------------------------------===//
3010// WsloopOp
3011//===----------------------------------------------------------------------===//
3012
3013void WsloopOp::build(OpBuilder &builder, OperationState &state,
3014 ArrayRef<NamedAttribute> attributes) {
3015 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3016 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3017 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3018 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3019 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3020 /*private_needs_barrier=*/false,
3021 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3022 /*reduction_byref=*/nullptr,
3023 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3024 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3025 /*schedule_simd=*/false);
3026 state.addAttributes(attributes);
3027}
3028
3029void WsloopOp::build(OpBuilder &builder, OperationState &state,
3030 const WsloopOperands &clauses) {
3031 MLIRContext *ctx = builder.getContext();
3032 // TODO: Store clauses in op: allocateVars, allocatorVars
3033 WsloopOp::build(
3034 builder, state,
3035 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3036 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3037 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3038 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3039 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3040 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3041 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3042 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3043}
3044
3045LogicalResult WsloopOp::verify() {
3046 if (failed(
3047 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3048 return failure();
3049 if (getLinearVars().size() &&
3050 getLinearVarTypes().value().size() != getLinearVars().size())
3051 return emitError() << "Ill-formed type attributes for linear variables";
3052 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3053 getReductionByref());
3054}
3055
3056LogicalResult WsloopOp::verifyRegions() {
3057 bool isCompositeChildLeaf =
3058 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3059
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3061 if (!isComposite())
3062 return emitError()
3063 << "'omp.composite' attribute missing from composite wrapper";
3064
3065 // Check for the allowed leaf constructs that may appear in a composite
3066 // construct directly after DO/FOR.
3067 if (!isa<SimdOp>(nested))
3068 return emitError() << "only supported nested wrapper is 'omp.simd'";
3069
3070 } else if (isComposite() && !isCompositeChildLeaf) {
3071 return emitError()
3072 << "'omp.composite' attribute present in non-composite wrapper";
3073 } else if (!isComposite() && isCompositeChildLeaf) {
3074 return emitError()
3075 << "'omp.composite' attribute missing from composite wrapper";
3076 }
3077
3078 return success();
3079}
3080
3081//===----------------------------------------------------------------------===//
3082// Simd construct [2.9.3.1]
3083//===----------------------------------------------------------------------===//
3084
3085void SimdOp::build(OpBuilder &builder, OperationState &state,
3086 const SimdOperands &clauses) {
3087 MLIRContext *ctx = builder.getContext();
3088 SimdOp::build(builder, state, clauses.alignedVars,
3089 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3090 clauses.linearVars, clauses.linearStepVars,
3091 clauses.linearVarTypes, clauses.linearModifiers,
3092 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3093 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3094 clauses.privateNeedsBarrier, clauses.reductionMod,
3095 clauses.reductionVars,
3096 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3097 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3098 clauses.simdlen);
3099}
3100
3101LogicalResult SimdOp::verify() {
3102 if (getSimdlen().has_value() && getSafelen().has_value() &&
3103 getSimdlen().value() > getSafelen().value())
3104 return emitOpError()
3105 << "simdlen clause and safelen clause are both present, but the "
3106 "simdlen value is not less than or equal to safelen value";
3107
3108 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
3109 return failure();
3110
3111 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
3112 return failure();
3113
3114 if (failed(
3115 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3116 return failure();
3117
3118 bool isCompositeChildLeaf =
3119 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3120
3121 if (!isComposite() && isCompositeChildLeaf)
3122 return emitError()
3123 << "'omp.composite' attribute missing from composite wrapper";
3124
3125 if (isComposite() && !isCompositeChildLeaf)
3126 return emitError()
3127 << "'omp.composite' attribute present in non-composite wrapper";
3128
3129 // Firstprivate is not allowed for SIMD in the standard. Check that none of
3130 // the private decls are for firstprivate.
3131 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3132 if (privateSyms) {
3133 for (const Attribute &sym : *privateSyms) {
3134 auto symRef = cast<SymbolRefAttr>(sym);
3135 omp::PrivateClauseOp privatizer =
3137 getOperation(), symRef);
3138 if (!privatizer)
3139 return emitError() << "Cannot find privatizer '" << symRef << "'";
3140 if (privatizer.getDataSharingType() ==
3141 DataSharingClauseType::FirstPrivate)
3142 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
3143 }
3144 }
3145
3146 if (getLinearVars().size() &&
3147 getLinearVarTypes().value().size() != getLinearVars().size())
3148 return emitError() << "Ill-formed type attributes for linear variables";
3149 return success();
3150}
3151
3152LogicalResult SimdOp::verifyRegions() {
3153 if (getNestedWrapper())
3154 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3155
3156 return success();
3157}
3158
3159//===----------------------------------------------------------------------===//
3160// Distribute construct [2.9.4.1]
3161//===----------------------------------------------------------------------===//
3162
3163void DistributeOp::build(OpBuilder &builder, OperationState &state,
3164 const DistributeOperands &clauses) {
3165 DistributeOp::build(builder, state, clauses.allocateVars,
3166 clauses.allocatorVars, clauses.distScheduleStatic,
3167 clauses.distScheduleChunkSize, clauses.order,
3168 clauses.orderMod, clauses.privateVars,
3169 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3170 clauses.privateNeedsBarrier);
3171}
3172
3173LogicalResult DistributeOp::verify() {
3174 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3175 return emitOpError() << "chunk size set without "
3176 "dist_schedule_static being present";
3177
3178 if (getAllocateVars().size() != getAllocatorVars().size())
3179 return emitError(
3180 "expected equal sizes for allocate and allocator variables");
3181
3182 return success();
3183}
3184
3185LogicalResult DistributeOp::verifyRegions() {
3186 if (LoopWrapperInterface nested = getNestedWrapper()) {
3187 if (!isComposite())
3188 return emitError()
3189 << "'omp.composite' attribute missing from composite wrapper";
3190 // Check for the allowed leaf constructs that may appear in a composite
3191 // construct directly after DISTRIBUTE.
3192 if (isa<WsloopOp>(nested)) {
3193 Operation *parentOp = (*this)->getParentOp();
3194 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3195 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3196 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3197 "when a composite 'omp.parallel' is the direct "
3198 "parent";
3199 }
3200 } else if (!isa<SimdOp>(nested))
3201 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3202 "'omp.wsloop'";
3203 } else if (isComposite()) {
3204 return emitError()
3205 << "'omp.composite' attribute present in non-composite wrapper";
3206 }
3207
3208 return success();
3209}
3210
3211//===----------------------------------------------------------------------===//
3212// DeclareMapperOp / DeclareMapperInfoOp
3213//===----------------------------------------------------------------------===//
3214
3215LogicalResult DeclareMapperInfoOp::verify() {
3216 return verifyMapClause(*this, getMapVars());
3217}
3218
3219LogicalResult DeclareMapperOp::verifyRegions() {
3220 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3221 getRegion().getBlocks().front().getTerminator()))
3222 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3223
3224 return success();
3225}
3226
3227//===----------------------------------------------------------------------===//
3228// DeclareReductionOp
3229//===----------------------------------------------------------------------===//
3230
3231LogicalResult DeclareReductionOp::verifyRegions() {
3232 if (!getAllocRegion().empty()) {
3233 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3234 if (yieldOp.getResults().size() != 1 ||
3235 yieldOp.getResults().getTypes()[0] != getType())
3236 return emitOpError() << "expects alloc region to yield a value "
3237 "of the reduction type";
3238 }
3239 }
3240
3241 if (getInitializerRegion().empty())
3242 return emitOpError() << "expects non-empty initializer region";
3243 Block &initializerEntryBlock = getInitializerRegion().front();
3244
3245 if (initializerEntryBlock.getNumArguments() == 1) {
3246 if (!getAllocRegion().empty())
3247 return emitOpError() << "expects two arguments to the initializer region "
3248 "when an allocation region is used";
3249 } else if (initializerEntryBlock.getNumArguments() == 2) {
3250 if (getAllocRegion().empty())
3251 return emitOpError() << "expects one argument to the initializer region "
3252 "when no allocation region is used";
3253 } else {
3254 return emitOpError()
3255 << "expects one or two arguments to the initializer region";
3256 }
3257
3258 for (mlir::Value arg : initializerEntryBlock.getArguments())
3259 if (arg.getType() != getType())
3260 return emitOpError() << "expects initializer region argument to match "
3261 "the reduction type";
3262
3263 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3264 if (yieldOp.getResults().size() != 1 ||
3265 yieldOp.getResults().getTypes()[0] != getType())
3266 return emitOpError() << "expects initializer region to yield a value "
3267 "of the reduction type";
3268 }
3269
3270 if (getReductionRegion().empty())
3271 return emitOpError() << "expects non-empty reduction region";
3272 Block &reductionEntryBlock = getReductionRegion().front();
3273 if (reductionEntryBlock.getNumArguments() != 2 ||
3274 reductionEntryBlock.getArgumentTypes()[0] !=
3275 reductionEntryBlock.getArgumentTypes()[1] ||
3276 reductionEntryBlock.getArgumentTypes()[0] != getType())
3277 return emitOpError() << "expects reduction region with two arguments of "
3278 "the reduction type";
3279 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3280 if (yieldOp.getResults().size() != 1 ||
3281 yieldOp.getResults().getTypes()[0] != getType())
3282 return emitOpError() << "expects reduction region to yield a value "
3283 "of the reduction type";
3284 }
3285
3286 if (!getAtomicReductionRegion().empty()) {
3287 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3288 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3289 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3290 atomicReductionEntryBlock.getArgumentTypes()[1])
3291 return emitOpError() << "expects atomic reduction region with two "
3292 "arguments of the same type";
3293 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3294 atomicReductionEntryBlock.getArgumentTypes()[0]);
3295 if (!ptrType ||
3296 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3297 return emitOpError() << "expects atomic reduction region arguments to "
3298 "be accumulators containing the reduction type";
3299 }
3300
3301 if (getCleanupRegion().empty())
3302 return success();
3303 Block &cleanupEntryBlock = getCleanupRegion().front();
3304 if (cleanupEntryBlock.getNumArguments() != 1 ||
3305 cleanupEntryBlock.getArgument(0).getType() != getType())
3306 return emitOpError() << "expects cleanup region with one argument "
3307 "of the reduction type";
3308
3309 return success();
3310}
3311
3312//===----------------------------------------------------------------------===//
3313// TaskOp
3314//===----------------------------------------------------------------------===//
3315
3316void TaskOp::build(OpBuilder &builder, OperationState &state,
3317 const TaskOperands &clauses) {
3318 MLIRContext *ctx = builder.getContext();
3319 TaskOp::build(
3320 builder, state, clauses.iterated, clauses.affinityVars,
3321 clauses.allocateVars, clauses.allocatorVars,
3322 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3323 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3324 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3325 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3326 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3327 clauses.priority, /*private_vars=*/clauses.privateVars,
3328 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3329 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3330}
3331
3332LogicalResult TaskOp::verify() {
3333 LogicalResult verifyDependVars =
3334 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3335 getDependIteratedKinds(), getDependIterated());
3336 return failed(verifyDependVars)
3337 ? verifyDependVars
3338 : verifyReductionVarList(*this, getInReductionSyms(),
3339 getInReductionVars(),
3340 getInReductionByref());
3341}
3342
3343//===----------------------------------------------------------------------===//
3344// TaskgroupOp
3345//===----------------------------------------------------------------------===//
3346
3347void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3348 const TaskgroupOperands &clauses) {
3349 MLIRContext *ctx = builder.getContext();
3350 TaskgroupOp::build(builder, state, clauses.allocateVars,
3351 clauses.allocatorVars, clauses.taskReductionVars,
3352 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3353 makeArrayAttr(ctx, clauses.taskReductionSyms));
3354}
3355
3356LogicalResult TaskgroupOp::verify() {
3357 return verifyReductionVarList(*this, getTaskReductionSyms(),
3358 getTaskReductionVars(),
3359 getTaskReductionByref());
3360}
3361
3362//===----------------------------------------------------------------------===//
3363// TaskloopContextOp
3364//===----------------------------------------------------------------------===//
3365
3366void TaskloopContextOp::build(OpBuilder &builder, OperationState &state,
3367 const TaskloopContextOperands &clauses) {
3368 MLIRContext *ctx = builder.getContext();
3369 TaskloopContextOp::build(
3370 builder, state, clauses.allocateVars, clauses.allocatorVars,
3371 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3372 clauses.inReductionVars,
3373 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3374 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3375 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3376 /*private_vars=*/clauses.privateVars,
3377 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3378 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3379 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3380 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3381}
3382
3383TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3384 return cast<TaskloopWrapperOp>(
3385 *llvm::find_if(getRegion().front(), [](mlir::Operation &op) {
3386 return isa<TaskloopWrapperOp>(op);
3387 }));
3388}
3389
3390LogicalResult TaskloopContextOp::verify() {
3391 if (getAllocateVars().size() != getAllocatorVars().size())
3392 return emitError(
3393 "expected equal sizes for allocate and allocator variables");
3394 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3395 getReductionVars(), getReductionByref())) ||
3396 failed(verifyReductionVarList(*this, getInReductionSyms(),
3397 getInReductionVars(),
3398 getInReductionByref())))
3399 return failure();
3400
3401 if (!getReductionVars().empty() && getNogroup())
3402 return emitError("if a reduction clause is present on the taskloop "
3403 "directive, the nogroup clause must not be specified");
3404 for (auto var : getReductionVars()) {
3405 if (llvm::is_contained(getInReductionVars(), var))
3406 return emitError("the same list item cannot appear in both a reduction "
3407 "and an in_reduction clause");
3408 }
3409
3410 if (getGrainsize() && getNumTasks()) {
3411 return emitError(
3412 "the grainsize clause and num_tasks clause are mutually exclusive and "
3413 "may not appear on the same taskloop directive");
3414 }
3415
3416 return success();
3417}
3418
3419LogicalResult TaskloopContextOp::verifyRegions() {
3420 Region &region = getRegion();
3421 if (region.empty())
3422 return emitOpError() << "expected non-empty region";
3423
3424 auto count = llvm::count_if(region.front(), [](mlir::Operation &op) {
3425 return isa<TaskloopWrapperOp>(op);
3426 });
3427 if (count != 1)
3428 return emitOpError()
3429 << "expected exactly 1 TaskloopWrapperOp directly nested in "
3430 "the region, but "
3431 << count << " were found";
3432 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3433
3434 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3435 // This will fail the verifier for TaskloopWrapperOp and print an error
3436 // message there.
3437 if (!loopNestOp)
3438 return failure();
3439
3440 auto isDefinedInTaskloopContext = [&](Value value) {
3441 // A region is considered an ancestor of itself
3442 return region.isAncestor(value.getParentRegion());
3443 };
3444 auto hasTaskloopLocalBound = [&](OperandRange range) {
3445 return llvm::any_of(range, isDefinedInTaskloopContext);
3446 };
3447
3448 if (hasTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3449 hasTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3450 hasTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3451 return emitOpError() << "expects loop bounds and steps to be defined "
3452 "outside of the taskloop.context region";
3453 }
3454
3455 return success();
3456}
3457
3458//===----------------------------------------------------------------------===//
3459// TaskloopWrapperOp
3460//===----------------------------------------------------------------------===//
3461
3462void TaskloopWrapperOp::build(OpBuilder &builder, OperationState &state,
3463 const TaskloopWrapperOperands &clauses) {
3464 TaskloopWrapperOp::build(builder, state);
3465}
3466
3467TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3468 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3469}
3470
3471LogicalResult TaskloopWrapperOp::verify() {
3472 TaskloopContextOp context = getTaskloopContext();
3473 if (!context)
3474 return emitOpError() << "expected to be nested in a taskloop context op";
3475 return success();
3476}
3477
3478LogicalResult TaskloopWrapperOp::verifyRegions() {
3479 if (LoopWrapperInterface nested = getNestedWrapper()) {
3480 if (!isComposite())
3481 return emitError()
3482 << "'omp.composite' attribute missing from composite wrapper";
3483
3484 // Check for the allowed leaf constructs that may appear in a composite
3485 // construct directly after TASKLOOP.
3486 if (!isa<SimdOp>(nested))
3487 return emitError() << "only supported nested wrapper is 'omp.simd'";
3488 } else if (isComposite()) {
3489 return emitError()
3490 << "'omp.composite' attribute present in non-composite wrapper";
3491 }
3492
3493 return success();
3494}
3495
3496//===----------------------------------------------------------------------===//
3497// LoopNestOp
3498//===----------------------------------------------------------------------===//
3499
3500ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3501 // Parse an opening `(` followed by induction variables followed by `)`
3504 Type loopVarType;
3506 parser.parseColonType(loopVarType) ||
3507 // Parse loop bounds.
3508 parser.parseEqual() ||
3509 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3510 parser.parseKeyword("to") ||
3511 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3512 return failure();
3513
3514 for (auto &iv : ivs)
3515 iv.type = loopVarType;
3516
3517 auto *ctx = parser.getBuilder().getContext();
3518 // Parse "inclusive" flag.
3519 if (succeeded(parser.parseOptionalKeyword("inclusive")))
3520 result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3521
3522 // Parse step values.
3524 if (parser.parseKeyword("step") ||
3525 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3526 return failure();
3527
3528 // Parse collapse
3529 int64_t value = 0;
3530 if (!parser.parseOptionalKeyword("collapse") &&
3531 (parser.parseLParen() || parser.parseInteger(value) ||
3532 parser.parseRParen()))
3533 return failure();
3534 if (value > 1)
3535 result.addAttribute(
3536 "collapse_num_loops",
3537 IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3538
3539 // Parse tiles
3541 auto parseTiles = [&]() -> ParseResult {
3542 int64_t tile;
3543 if (parser.parseInteger(tile))
3544 return failure();
3545 tiles.push_back(tile);
3546 return success();
3547 };
3548
3549 if (!parser.parseOptionalKeyword("tiles") &&
3550 (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3551 parser.parseRParen()))
3552 return failure();
3553
3554 if (tiles.size() > 0)
3555 result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3556
3557 // Parse the body.
3558 Region *region = result.addRegion();
3559 if (parser.parseRegion(*region, ivs))
3560 return failure();
3561
3562 // Resolve operands.
3563 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3564 parser.resolveOperands(ubs, loopVarType, result.operands) ||
3565 parser.resolveOperands(steps, loopVarType, result.operands))
3566 return failure();
3567
3568 // Parse the optional attribute list.
3569 return parser.parseOptionalAttrDict(result.attributes);
3570}
3571
3572void LoopNestOp::print(OpAsmPrinter &p) {
3573 Region &region = getRegion();
3574 auto args = region.getArguments();
3575 p << " (" << args << ") : " << args[0].getType() << " = ("
3576 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3577 if (getLoopInclusive())
3578 p << "inclusive ";
3579 p << "step (" << getLoopSteps() << ") ";
3580 if (int64_t numCollapse = getCollapseNumLoops())
3581 if (numCollapse > 1)
3582 p << "collapse(" << numCollapse << ") ";
3583
3584 if (const auto tiles = getTileSizes())
3585 p << "tiles(" << tiles.value() << ") ";
3586
3587 p.printRegion(region, /*printEntryBlockArgs=*/false);
3588}
3589
3590void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3591 const LoopNestOperands &clauses) {
3592 MLIRContext *ctx = builder.getContext();
3593 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3594 clauses.loopLowerBounds, clauses.loopUpperBounds,
3595 clauses.loopSteps, clauses.loopInclusive,
3596 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3597}
3598
3599LogicalResult LoopNestOp::verify() {
3600 if (getLoopLowerBounds().empty())
3601 return emitOpError() << "must represent at least one loop";
3602
3603 if (getLoopLowerBounds().size() != getIVs().size())
3604 return emitOpError() << "number of range arguments and IVs do not match";
3605
3606 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3607 if (lb.getType() != iv.getType())
3608 return emitOpError()
3609 << "range argument type does not match corresponding IV type";
3610 }
3611
3612 uint64_t numIVs = getIVs().size();
3613
3614 if (const auto &numCollapse = getCollapseNumLoops())
3615 if (numCollapse > numIVs)
3616 return emitOpError()
3617 << "collapse value is larger than the number of loops";
3618
3619 if (const auto &tiles = getTileSizes())
3620 if (tiles.value().size() > numIVs)
3621 return emitOpError() << "too few canonical loops for tile dimensions";
3622
3623 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3624 return emitOpError() << "expects parent op to be a loop wrapper";
3625
3626 return success();
3627}
3628
3629void LoopNestOp::gatherWrappers(
3631 Operation *parent = (*this)->getParentOp();
3632 while (auto wrapper =
3633 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3634 wrappers.push_back(wrapper);
3635 parent = parent->getParentOp();
3636 }
3637}
3638
3639//===----------------------------------------------------------------------===//
3640// OpenMP canonical loop handling
3641//===----------------------------------------------------------------------===//
3642
3643std::tuple<NewCliOp, OpOperand *, OpOperand *>
3644mlir::omp ::decodeCli(Value cli) {
3645
3646 // Defining a CLI for a generated loop is optional; if there is none then
3647 // there is no followup-tranformation
3648 if (!cli)
3649 return {{}, nullptr, nullptr};
3650
3651 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3652 "Unexpected type of cli");
3653
3654 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3655 OpOperand *gen = nullptr;
3656 OpOperand *cons = nullptr;
3657 for (OpOperand &use : cli.getUses()) {
3658 auto op = cast<LoopTransformationInterface>(use.getOwner());
3659
3660 unsigned opnum = use.getOperandNumber();
3661 if (op.isGeneratee(opnum)) {
3662 assert(!gen && "Each CLI may have at most one def");
3663 gen = &use;
3664 } else if (op.isApplyee(opnum)) {
3665 assert(!cons && "Each CLI may have at most one consumer");
3666 cons = &use;
3667 } else {
3668 llvm_unreachable("Unexpected operand for a CLI");
3669 }
3670 }
3671
3672 return {create, gen, cons};
3673}
3674
3675void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3676 ::mlir::OperationState &odsState) {
3677 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3678}
3679
3680void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3681 Value result = getResult();
3682 auto [newCli, gen, cons] = decodeCli(result);
3683
3684 // Structured binding `gen` cannot be captured in lambdas before C++20
3685 OpOperand *generator = gen;
3686
3687 // Derive the CLI variable name from its generator:
3688 // * "canonloop" for omp.canonical_loop
3689 // * custom name for loop transformation generatees
3690 // * "cli" as fallback if no generator
3691 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3692 // at that level
3693 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3694 // the index of that region
3695 std::string cliName{"cli"};
3696 if (gen) {
3697 cliName =
3699 .Case([&](CanonicalLoopOp op) {
3700 return generateLoopNestingName("canonloop", op);
3701 })
3702 .Case([&](UnrollHeuristicOp op) -> std::string {
3703 llvm_unreachable("heuristic unrolling does not generate a loop");
3704 })
3705 .Case([&](FuseOp op) -> std::string {
3706 unsigned opnum = generator->getOperandNumber();
3707 // The position of the first loop to be fused is the same position
3708 // as the resulting fused loop
3709 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3710 return "canonloop_fuse";
3711 else
3712 return "fused";
3713 })
3714 .Case([&](TileOp op) -> std::string {
3715 auto [generateesFirst, generateesCount] =
3716 op.getGenerateesODSOperandIndexAndLength();
3717 unsigned firstGrid = generateesFirst;
3718 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3719 unsigned end = generateesFirst + generateesCount;
3720 unsigned opnum = generator->getOperandNumber();
3721 // In the OpenMP apply and looprange clauses, indices are 1-based
3722 if (firstGrid <= opnum && opnum < firstIntratile) {
3723 unsigned gridnum = opnum - firstGrid + 1;
3724 return ("grid" + Twine(gridnum)).str();
3725 }
3726 if (firstIntratile <= opnum && opnum < end) {
3727 unsigned intratilenum = opnum - firstIntratile + 1;
3728 return ("intratile" + Twine(intratilenum)).str();
3729 }
3730 llvm_unreachable("Unexpected generatee argument");
3731 })
3732 .DefaultUnreachable("TODO: Custom name for this operation");
3733 }
3734
3735 setNameFn(result, cliName);
3736}
3737
3738LogicalResult NewCliOp::verify() {
3739 Value cli = getResult();
3740
3741 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3742 "Unexpected type of cli");
3743
3744 // Check that the CLI is used in at most generator and one consumer
3745 OpOperand *gen = nullptr;
3746 OpOperand *cons = nullptr;
3747 for (mlir::OpOperand &use : cli.getUses()) {
3748 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3749
3750 unsigned opnum = use.getOperandNumber();
3751 if (op.isGeneratee(opnum)) {
3752 if (gen) {
3753 InFlightDiagnostic error =
3754 emitOpError("CLI must have at most one generator");
3755 error.attachNote(gen->getOwner()->getLoc())
3756 .append("first generator here:");
3757 error.attachNote(use.getOwner()->getLoc())
3758 .append("second generator here:");
3759 return error;
3760 }
3761
3762 gen = &use;
3763 } else if (op.isApplyee(opnum)) {
3764 if (cons) {
3765 InFlightDiagnostic error =
3766 emitOpError("CLI must have at most one consumer");
3767 error.attachNote(cons->getOwner()->getLoc())
3768 .append("first consumer here:")
3769 .appendOp(*cons->getOwner(),
3770 OpPrintingFlags().printGenericOpForm());
3771 error.attachNote(use.getOwner()->getLoc())
3772 .append("second consumer here:")
3773 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3774 return error;
3775 }
3776
3777 cons = &use;
3778 } else {
3779 llvm_unreachable("Unexpected operand for a CLI");
3780 }
3781 }
3782
3783 // If the CLI is source of a transformation, it must have a generator
3784 if (cons && !gen) {
3785 InFlightDiagnostic error = emitOpError("CLI has no generator");
3786 error.attachNote(cons->getOwner()->getLoc())
3787 .append("see consumer here: ")
3788 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3789 return error;
3790 }
3791
3792 return success();
3793}
3794
3795void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3796 Value tripCount) {
3797 odsState.addOperands(tripCount);
3798 odsState.addOperands(Value());
3799 (void)odsState.addRegion();
3800}
3801
3802void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3803 Value tripCount, ::mlir::Value cli) {
3804 odsState.addOperands(tripCount);
3805 odsState.addOperands(cli);
3806 (void)odsState.addRegion();
3807}
3808
3809void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3810 setNameFn(&getRegion().front(), "body_entry");
3811}
3812
3813void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3814 OpAsmSetValueNameFn setNameFn) {
3815 std::string ivName = generateLoopNestingName("iv", *this);
3816 setNameFn(region.getArgument(0), ivName);
3817}
3818
3819void CanonicalLoopOp::print(OpAsmPrinter &p) {
3820 if (getCli())
3821 p << '(' << getCli() << ')';
3822 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3823 << " in range(" << getTripCount() << ") ";
3824
3825 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3826 /*printBlockTerminators=*/true);
3827
3828 p.printOptionalAttrDict((*this)->getAttrs());
3829}
3830
3831mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3833 CanonicalLoopInfoType cliType =
3834 CanonicalLoopInfoType::get(parser.getContext());
3835
3836 // Parse (optional) omp.cli identifier
3838 SmallVector<mlir::Value, 1> cliOperand;
3839 if (!parser.parseOptionalLParen()) {
3840 if (parser.parseOperand(cli) ||
3841 parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3842 return failure();
3843 }
3844
3845 // We derive the type of tripCount from inductionVariable. MLIR requires the
3846 // type of tripCount to be known when calling resolveOperand so we have parse
3847 // the type before processing the inductionVariable.
3848 OpAsmParser::Argument inductionVariable;
3850 if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3851 parser.parseKeyword("in") || parser.parseKeyword("range") ||
3852 parser.parseLParen() || parser.parseOperand(tripcount) ||
3853 parser.parseRParen() ||
3854 parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3855 return failure();
3856
3857 // Parse the loop body.
3858 Region *region = result.addRegion();
3859 if (parser.parseRegion(*region, {inductionVariable}))
3860 return failure();
3861
3862 // We parsed the cli operand forst, but because it is optional, it must be
3863 // last in the operand list.
3864 result.operands.append(cliOperand);
3865
3866 // Parse the optional attribute list.
3867 if (parser.parseOptionalAttrDict(result.attributes))
3868 return failure();
3869
3870 return mlir::success();
3871}
3872
3873LogicalResult CanonicalLoopOp::verify() {
3874 // The region's entry must accept the induction variable
3875 // It can also be empty if just created
3876 if (!getRegion().empty()) {
3877 Region &region = getRegion();
3878 if (region.getNumArguments() != 1)
3879 return emitOpError(
3880 "Canonical loop region must have exactly one argument");
3881
3882 if (getInductionVar().getType() != getTripCount().getType())
3883 return emitOpError(
3884 "Region argument must be the same type as the trip count");
3885 }
3886
3887 return success();
3888}
3889
3890Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3891
3892std::pair<unsigned, unsigned>
3893CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3894 // No applyees
3895 return {0, 0};
3896}
3897
3898std::pair<unsigned, unsigned>
3899CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3900 return getODSOperandIndexAndLength(odsIndex_cli);
3901}
3902
3903//===----------------------------------------------------------------------===//
3904// UnrollHeuristicOp
3905//===----------------------------------------------------------------------===//
3906
3907void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3908 ::mlir::OperationState &odsState,
3909 ::mlir::Value cli) {
3910 odsState.addOperands(cli);
3911}
3912
3913void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3914 p << '(' << getApplyee() << ')';
3915
3916 p.printOptionalAttrDict((*this)->getAttrs());
3917}
3918
3919mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3921 auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3922
3923 if (parser.parseLParen())
3924 return failure();
3925
3927 if (parser.parseOperand(applyee) ||
3928 parser.resolveOperand(applyee, cliType, result.operands))
3929 return failure();
3930
3931 if (parser.parseRParen())
3932 return failure();
3933
3934 // Optional output loop (full unrolling has none)
3935 if (!parser.parseOptionalArrow()) {
3936 if (parser.parseLParen() || parser.parseRParen())
3937 return failure();
3938 }
3939
3940 // Parse the optional attribute list.
3941 if (parser.parseOptionalAttrDict(result.attributes))
3942 return failure();
3943
3944 return mlir::success();
3945}
3946
3947std::pair<unsigned, unsigned>
3948UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3949 return getODSOperandIndexAndLength(odsIndex_applyee);
3950}
3951
3952std::pair<unsigned, unsigned>
3953UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3954 return {0, 0};
3955}
3956
3957//===----------------------------------------------------------------------===//
3958// TileOp
3959//===----------------------------------------------------------------------===//
3960
3961static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3962 OperandRange generatees,
3963 OperandRange applyees) {
3964 if (!generatees.empty())
3965 p << '(' << llvm::interleaved(generatees) << ')';
3966
3967 if (!applyees.empty())
3968 p << " <- (" << llvm::interleaved(applyees) << ')';
3969}
3970
3971static ParseResult parseLoopTransformClis(
3972 OpAsmParser &parser,
3975 if (parser.parseOptionalLess()) {
3976 // Syntax 1: generatees present
3977
3978 if (parser.parseOperandList(generateesOperands,
3980 return failure();
3981
3982 if (parser.parseLess())
3983 return failure();
3984 } else {
3985 // Syntax 2: generatees omitted
3986 }
3987
3988 // Parse `<-` (`<` has already been parsed)
3989 if (parser.parseMinus())
3990 return failure();
3991
3992 if (parser.parseOperandList(applyeesOperands,
3994 return failure();
3995
3996 return success();
3997}
3998
3999/// Check properties of the loop nest consisting of the transformation's
4000/// applyees:
4001/// 1. They are nested inside each other
4002/// 2. They are perfectly nested
4003/// (no code with side-effects in-between the loops)
4004/// 3. They are rectangular
4005/// (loop bounds are invariant in respect to the outer loops)
4006///
4007/// TODO: Generalize for LoopTransformationInterface.
4008static LogicalResult checkApplyeesNesting(TileOp op) {
4009 // Collect the loops from the nest
4010 bool isOnlyCanonLoops = true;
4012 for (Value applyee : op.getApplyees()) {
4013 auto [create, gen, cons] = decodeCli(applyee);
4014
4015 if (!gen)
4016 return op.emitOpError() << "applyee CLI has no generator";
4017
4018 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4019 canonLoops.push_back(loop);
4020 if (!loop)
4021 isOnlyCanonLoops = false;
4022 }
4023
4024 // FIXME: We currently can only verify non-rectangularity and perfect nest of
4025 // omp.canonical_loop.
4026 if (!isOnlyCanonLoops)
4027 return success();
4028
4029 DenseSet<Value> parentIVs;
4030 for (auto i : llvm::seq<int>(1, canonLoops.size())) {
4031 auto parentLoop = canonLoops[i - 1];
4032 auto loop = canonLoops[i];
4033
4034 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4035 return op.emitOpError()
4036 << "tiled loop nest must be nested within each other";
4037
4038 parentIVs.insert(parentLoop.getInductionVar());
4039
4040 // Canonical loop must be perfectly nested, i.e. the body of the parent must
4041 // only contain the omp.canonical_loop of the nested loops, and
4042 // omp.terminator
4043 bool isPerfectlyNested = [&]() {
4044 auto &parentBody = parentLoop.getRegion();
4045 if (!parentBody.hasOneBlock())
4046 return false;
4047 auto &parentBlock = parentBody.getBlocks().front();
4048
4049 auto nestedLoopIt = parentBlock.begin();
4050 if (nestedLoopIt == parentBlock.end() ||
4051 (&*nestedLoopIt != loop.getOperation()))
4052 return false;
4053
4054 auto termIt = std::next(nestedLoopIt);
4055 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4056 return false;
4057
4058 if (std::next(termIt) != parentBlock.end())
4059 return false;
4060
4061 return true;
4062 }();
4063 if (!isPerfectlyNested)
4064 return op.emitOpError() << "tiled loop nest must be perfectly nested";
4065
4066 if (parentIVs.contains(loop.getTripCount()))
4067 return op.emitOpError() << "tiled loop nest must be rectangular";
4068 }
4069
4070 // TODO: The tile sizes must be computed before the loop, but checking this
4071 // requires dominance analysis. For instance:
4072 //
4073 // %canonloop = omp.new_cli
4074 // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
4075 // // write to %x
4076 // omp.terminator
4077 // }
4078 // %ts = llvm.load %x
4079 // omp.tile <- (%canonloop) sizes(%ts : i32)
4080
4081 return success();
4082}
4083
4084LogicalResult TileOp::verify() {
4085 if (getApplyees().empty())
4086 return emitOpError() << "must apply to at least one loop";
4087
4088 if (getSizes().size() != getApplyees().size())
4089 return emitOpError() << "there must be one tile size for each applyee";
4090
4091 if (!getGeneratees().empty() &&
4092 2 * getSizes().size() != getGeneratees().size())
4093 return emitOpError()
4094 << "expecting two times the number of generatees than applyees";
4095
4096 return checkApplyeesNesting(*this);
4097}
4098
4099std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4100 return getODSOperandIndexAndLength(odsIndex_applyees);
4101}
4102
4103std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4104 return getODSOperandIndexAndLength(odsIndex_generatees);
4105}
4106
4107//===----------------------------------------------------------------------===//
4108// FuseOp
4109//===----------------------------------------------------------------------===//
4110
4111static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4112 OperandRange generatees,
4113 OperandRange applyees) {
4114 if (!generatees.empty())
4115 p << '(' << llvm::interleaved(generatees) << ')';
4116
4117 if (!applyees.empty())
4118 p << " <- (" << llvm::interleaved(applyees) << ')';
4119}
4120
4121LogicalResult FuseOp::verify() {
4122 if (getApplyees().size() < 2)
4123 return emitOpError() << "must apply to at least two loops";
4124
4125 if (getFirst().has_value() && getCount().has_value()) {
4126 int64_t first = getFirst().value();
4127 int64_t count = getCount().value();
4128 if ((unsigned)(first + count - 1) > getApplyees().size())
4129 return emitOpError() << "the numbers of applyees must be at least first "
4130 "minus one plus count attributes";
4131 if (!getGeneratees().empty() &&
4132 getGeneratees().size() != getApplyees().size() + 1 - count)
4133 return emitOpError() << "the number of generatees must be the number of "
4134 "aplyees plus one minus count";
4135
4136 } else {
4137 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4138 return emitOpError()
4139 << "in a complete fuse the number of generatees must be exactly 1";
4140 }
4141 for (auto &&applyee : getApplyees()) {
4142 auto [create, gen, cons] = decodeCli(applyee);
4143
4144 if (!gen)
4145 return emitOpError() << "applyee CLI has no generator";
4146 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4147 if (!loop)
4148 return emitOpError()
4149 << "currently only supports omp.canonical_loop as applyee";
4150 }
4151 return success();
4152}
4153std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4154 return getODSOperandIndexAndLength(odsIndex_applyees);
4155}
4156
4157std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4158 return getODSOperandIndexAndLength(odsIndex_generatees);
4159}
4160
4161//===----------------------------------------------------------------------===//
4162// Critical construct (2.17.1)
4163//===----------------------------------------------------------------------===//
4164
4165void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4166 const CriticalDeclareOperands &clauses) {
4167 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4168}
4169
4170LogicalResult CriticalDeclareOp::verify() {
4171 return verifySynchronizationHint(*this, getHint());
4172}
4173
4174LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4175 if (getNameAttr()) {
4176 SymbolRefAttr symbolRef = getNameAttr();
4177 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4178 *this, symbolRef);
4179 if (!decl) {
4180 return emitOpError() << "expected symbol reference " << symbolRef
4181 << " to point to a critical declaration";
4182 }
4183 }
4184
4185 return success();
4186}
4187
4188//===----------------------------------------------------------------------===//
4189// Ordered construct
4190//===----------------------------------------------------------------------===//
4191
4192static LogicalResult verifyOrderedParent(Operation &op) {
4193 bool hasRegion = op.getNumRegions() > 0;
4194 auto loopOp = op.getParentOfType<LoopNestOp>();
4195 if (!loopOp) {
4196 if (hasRegion)
4197 return success();
4198
4199 // TODO: Consider if this needs to be the case only for the standalone
4200 // variant of the ordered construct.
4201 return op.emitOpError() << "must be nested inside of a loop";
4202 }
4203
4204 Operation *wrapper = loopOp->getParentOp();
4205 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4206 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4207 if (!orderedAttr)
4208 return op.emitOpError() << "the enclosing worksharing-loop region must "
4209 "have an ordered clause";
4210
4211 if (hasRegion && orderedAttr.getInt() != 0)
4212 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4213 "have a parameter present";
4214
4215 if (!hasRegion && orderedAttr.getInt() == 0)
4216 return op.emitOpError() << "the enclosing loop's ordered clause must "
4217 "have a parameter present";
4218 } else if (!isa<SimdOp>(wrapper)) {
4219 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4220 "or worksharing simd loop";
4221 }
4222 return success();
4223}
4224
4225void OrderedOp::build(OpBuilder &builder, OperationState &state,
4226 const OrderedOperands &clauses) {
4227 OrderedOp::build(builder, state, clauses.doacrossDependType,
4228 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4229}
4230
4231LogicalResult OrderedOp::verify() {
4232 if (failed(verifyOrderedParent(**this)))
4233 return failure();
4234
4235 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4236 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4237 return emitOpError() << "number of variables in depend clause does not "
4238 << "match number of iteration variables in the "
4239 << "doacross loop";
4240
4241 return success();
4242}
4243
4244void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4245 const OrderedRegionOperands &clauses) {
4246 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4247}
4248
4249LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4250
4251//===----------------------------------------------------------------------===//
4252// TaskwaitOp
4253//===----------------------------------------------------------------------===//
4254
4255void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4256 const TaskwaitOperands &clauses) {
4257 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4258 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4259 /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
4260 /*depend_iterated=*/{}, /*nowait=*/nullptr);
4261}
4262
4263//===----------------------------------------------------------------------===//
4264// Verifier for AtomicReadOp
4265//===----------------------------------------------------------------------===//
4266
4267LogicalResult AtomicReadOp::verify() {
4268 if (verifyCommon().failed())
4269 return mlir::failure();
4270
4271 if (auto mo = getMemoryOrder()) {
4272 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4273 *mo == ClauseMemoryOrderKind::Release) {
4274 return emitError(
4275 "memory-order must not be acq_rel or release for atomic reads");
4276 }
4277 }
4278 return verifySynchronizationHint(*this, getHint());
4279}
4280
4281//===----------------------------------------------------------------------===//
4282// Verifier for AtomicWriteOp
4283//===----------------------------------------------------------------------===//
4284
4285LogicalResult AtomicWriteOp::verify() {
4286 if (verifyCommon().failed())
4287 return mlir::failure();
4288
4289 if (auto mo = getMemoryOrder()) {
4290 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4291 *mo == ClauseMemoryOrderKind::Acquire) {
4292 return emitError(
4293 "memory-order must not be acq_rel or acquire for atomic writes");
4294 }
4295 }
4296 return verifySynchronizationHint(*this, getHint());
4297}
4298
4299//===----------------------------------------------------------------------===//
4300// Verifier for AtomicUpdateOp
4301//===----------------------------------------------------------------------===//
4302
4303LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4304 PatternRewriter &rewriter) {
4305 if (op.isNoOp()) {
4306 rewriter.eraseOp(op);
4307 return success();
4308 }
4309 if (Value writeVal = op.getWriteOpVal()) {
4310 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4311 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4312 return success();
4313 }
4314 return failure();
4315}
4316
4317LogicalResult AtomicUpdateOp::verify() {
4318 if (verifyCommon().failed())
4319 return mlir::failure();
4320
4321 if (auto mo = getMemoryOrder()) {
4322 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4323 *mo == ClauseMemoryOrderKind::Acquire) {
4324 return emitError(
4325 "memory-order must not be acq_rel or acquire for atomic updates");
4326 }
4327 }
4328
4329 return verifySynchronizationHint(*this, getHint());
4330}
4331
4332LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4333
4334//===----------------------------------------------------------------------===//
4335// Verifier for AtomicCaptureOp
4336//===----------------------------------------------------------------------===//
4337
4338AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4339 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4340 return op;
4341 return dyn_cast<AtomicReadOp>(getSecondOp());
4342}
4343
4344AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4345 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4346 return op;
4347 return dyn_cast<AtomicWriteOp>(getSecondOp());
4348}
4349
4350AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4351 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4352 return op;
4353 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4354}
4355
4356LogicalResult AtomicCaptureOp::verify() {
4357 return verifySynchronizationHint(*this, getHint());
4358}
4359
4360LogicalResult AtomicCaptureOp::verifyRegions() {
4361 if (verifyRegionsCommon().failed())
4362 return mlir::failure();
4363
4364 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4365 return emitOpError(
4366 "operations inside capture region must not have hint clause");
4367
4368 if (getFirstOp()->getAttr("memory_order") ||
4369 getSecondOp()->getAttr("memory_order"))
4370 return emitOpError(
4371 "operations inside capture region must not have memory_order clause");
4372 return success();
4373}
4374
4375//===----------------------------------------------------------------------===//
4376// CancelOp
4377//===----------------------------------------------------------------------===//
4378
4379void CancelOp::build(OpBuilder &builder, OperationState &state,
4380 const CancelOperands &clauses) {
4381 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4382}
4383
4385 Operation *parent = thisOp->getParentOp();
4386 while (parent) {
4387 if (parent->getDialect() == thisOp->getDialect())
4388 return parent;
4389 parent = parent->getParentOp();
4390 }
4391 return nullptr;
4392}
4393
4394LogicalResult CancelOp::verify() {
4395 ClauseCancellationConstructType cct = getCancelDirective();
4396 // The next OpenMP operation in the chain of parents
4397 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4398 if (!structuralParent)
4399 return emitOpError() << "Orphaned cancel construct";
4400
4401 if ((cct == ClauseCancellationConstructType::Parallel) &&
4402 !mlir::isa<ParallelOp>(structuralParent)) {
4403 return emitOpError() << "cancel parallel must appear "
4404 << "inside a parallel region";
4405 }
4406 if (cct == ClauseCancellationConstructType::Loop) {
4407 // structural parent will be omp.loop_nest, directly nested inside
4408 // omp.wsloop
4409 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4410
4411 if (!wsloopOp) {
4412 return emitOpError()
4413 << "cancel loop must appear inside a worksharing-loop region";
4414 }
4415 if (wsloopOp.getNowaitAttr()) {
4416 return emitError() << "A worksharing construct that is canceled "
4417 << "must not have a nowait clause";
4418 }
4419 if (wsloopOp.getOrderedAttr()) {
4420 return emitError() << "A worksharing construct that is canceled "
4421 << "must not have an ordered clause";
4422 }
4423
4424 } else if (cct == ClauseCancellationConstructType::Sections) {
4425 // structural parent will be an omp.section, directly nested inside
4426 // omp.sections
4427 auto sectionsOp =
4428 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4429 if (!sectionsOp) {
4430 return emitOpError() << "cancel sections must appear "
4431 << "inside a sections region";
4432 }
4433 if (sectionsOp.getNowait()) {
4434 return emitError() << "A sections construct that is canceled "
4435 << "must not have a nowait clause";
4436 }
4437 }
4438 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4439 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4440 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4441 return emitOpError() << "cancel taskgroup must appear "
4442 << "inside a task region";
4443 }
4444 return success();
4445}
4446
4447//===----------------------------------------------------------------------===//
4448// CancellationPointOp
4449//===----------------------------------------------------------------------===//
4450
4451void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4452 const CancellationPointOperands &clauses) {
4453 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4454}
4455
4456LogicalResult CancellationPointOp::verify() {
4457 ClauseCancellationConstructType cct = getCancelDirective();
4458 // The next OpenMP operation in the chain of parents
4459 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4460 if (!structuralParent)
4461 return emitOpError() << "Orphaned cancellation point";
4462
4463 if ((cct == ClauseCancellationConstructType::Parallel) &&
4464 !mlir::isa<ParallelOp>(structuralParent)) {
4465 return emitOpError() << "cancellation point parallel must appear "
4466 << "inside a parallel region";
4467 }
4468 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4469 // find the wsloop
4470 if ((cct == ClauseCancellationConstructType::Loop) &&
4471 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4472 return emitOpError() << "cancellation point loop must appear "
4473 << "inside a worksharing-loop region";
4474 }
4475 if ((cct == ClauseCancellationConstructType::Sections) &&
4476 !mlir::isa<omp::SectionOp>(structuralParent)) {
4477 return emitOpError() << "cancellation point sections must appear "
4478 << "inside a sections region";
4479 }
4480 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4481 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4482 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4483 return emitOpError() << "cancellation point taskgroup must appear "
4484 << "inside a task region";
4485 }
4486 return success();
4487}
4488
4489//===----------------------------------------------------------------------===//
4490// MapBoundsOp
4491//===----------------------------------------------------------------------===//
4492
4493LogicalResult MapBoundsOp::verify() {
4494 auto extent = getExtent();
4495 auto upperbound = getUpperBound();
4496 if (!extent && !upperbound)
4497 return emitError("expected extent or upperbound.");
4498 return success();
4499}
4500
4501void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4502 TypeRange /*result_types*/, StringAttr symName,
4503 TypeAttr type) {
4504 PrivateClauseOp::build(
4505 odsBuilder, odsState, symName, type,
4506 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4507 DataSharingClauseType::Private));
4508}
4509
4510LogicalResult PrivateClauseOp::verifyRegions() {
4511 Type argType = getArgType();
4512 auto verifyTerminator = [&](Operation *terminator,
4513 bool yieldsValue) -> LogicalResult {
4514 if (!terminator->getBlock()->getSuccessors().empty())
4515 return success();
4516
4517 if (!llvm::isa<YieldOp>(terminator))
4518 return mlir::emitError(terminator->getLoc())
4519 << "expected exit block terminator to be an `omp.yield` op.";
4520
4521 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4522 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4523
4524 if (!yieldsValue) {
4525 if (yieldedTypes.empty())
4526 return success();
4527
4528 return mlir::emitError(terminator->getLoc())
4529 << "Did not expect any values to be yielded.";
4530 }
4531
4532 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4533 return success();
4534
4535 auto error = mlir::emitError(yieldOp.getLoc())
4536 << "Invalid yielded value. Expected type: " << argType
4537 << ", got: ";
4538
4539 if (yieldedTypes.empty())
4540 error << "None";
4541 else
4542 error << yieldedTypes;
4543
4544 return error;
4545 };
4546
4547 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4548 StringRef regionName,
4549 bool yieldsValue) -> LogicalResult {
4550 assert(!region.empty());
4551
4552 if (region.getNumArguments() != expectedNumArgs)
4553 return mlir::emitError(region.getLoc())
4554 << "`" << regionName << "`: "
4555 << "expected " << expectedNumArgs
4556 << " region arguments, got: " << region.getNumArguments();
4557
4558 for (Block &block : region) {
4559 // MLIR will verify the absence of the terminator for us.
4560 if (!block.mightHaveTerminator())
4561 continue;
4562
4563 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4564 return failure();
4565 }
4566
4567 return success();
4568 };
4569
4570 // Ensure all of the region arguments have the same type
4571 for (Region *region : getRegions())
4572 for (Type ty : region->getArgumentTypes())
4573 if (ty != argType)
4574 return emitError() << "Region argument type mismatch: got " << ty
4575 << " expected " << argType << ".";
4576
4577 mlir::Region &initRegion = getInitRegion();
4578 if (!initRegion.empty() &&
4579 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4580 /*yieldsValue=*/true)))
4581 return failure();
4582
4583 DataSharingClauseType dsType = getDataSharingType();
4584
4585 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4586 return emitError("`private` clauses do not require a `copy` region.");
4587
4588 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4589 return emitError(
4590 "`firstprivate` clauses require at least a `copy` region.");
4591
4592 if (dsType == DataSharingClauseType::FirstPrivate &&
4593 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4594 /*yieldsValue=*/true)))
4595 return failure();
4596
4597 if (!getDeallocRegion().empty() &&
4598 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4599 /*yieldsValue=*/false)))
4600 return failure();
4601
4602 return success();
4603}
4604
4605//===----------------------------------------------------------------------===//
4606// Spec 5.2: Masked construct (10.5)
4607//===----------------------------------------------------------------------===//
4608
4609void MaskedOp::build(OpBuilder &builder, OperationState &state,
4610 const MaskedOperands &clauses) {
4611 MaskedOp::build(builder, state, clauses.filteredThreadId);
4612}
4613
4614//===----------------------------------------------------------------------===//
4615// Spec 5.2: Scan construct (5.6)
4616//===----------------------------------------------------------------------===//
4617
4618void ScanOp::build(OpBuilder &builder, OperationState &state,
4619 const ScanOperands &clauses) {
4620 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4621}
4622
4623LogicalResult ScanOp::verify() {
4624 if (hasExclusiveVars() == hasInclusiveVars())
4625 return emitError(
4626 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4627 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4628 if (parentWsLoopOp.getReductionModAttr() &&
4629 parentWsLoopOp.getReductionModAttr().getValue() ==
4630 ReductionModifier::inscan)
4631 return success();
4632 }
4633 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4634 if (parentSimdOp.getReductionModAttr() &&
4635 parentSimdOp.getReductionModAttr().getValue() ==
4636 ReductionModifier::inscan)
4637 return success();
4638 }
4639 return emitError("SCAN directive needs to be enclosed within a parent "
4640 "worksharing loop construct or SIMD construct with INSCAN "
4641 "reduction modifier");
4642}
4643
4644/// Verifies align clause in allocate directive
4645
4646LogicalResult AllocateDirOp::verify() {
4647 std::optional<uint64_t> align = this->getAlign();
4648
4649 if (align.has_value()) {
4650 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4651 return emitError() << "ALIGN value : " << align.value()
4652 << " must be power of 2";
4653 }
4654
4655 return success();
4656}
4657
4658//===----------------------------------------------------------------------===//
4659// TargetAllocMemOp
4660//===----------------------------------------------------------------------===//
4661
4662mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4663 return getInTypeAttr().getValue();
4664}
4665
// Parses the custom assembly of the target_alloc_mem operation into `result`:
// a device operand with its type, the `in_type` type attribute, an optional
// parenthesized list of type parameters with their types, an optional
// comma-introduced list of shape operands, an optional attribute dictionary,
// and a single 64-bit integer result type.
// NOTE(review): this listing is missing several lines of the definition (the
// rest of the signature, the declarations of `deviceOperand`, `operands` and
// `typeVec`, and the start of two `if` conditions) - confirm against the
// complete file before editing.
4666/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4667/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4668/// attr-dict-without-keyword
4669static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4671 auto &builder = parser.getBuilder();
4672 bool hasOperands = false;
4673 std::int32_t typeparamsSize = 0;
4674
4675 // Parse device number as a new operand
4677 mlir::Type deviceType;
4678 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4679 return mlir::failure();
// The device operand is resolved immediately, so it is the first entry of
// result.operands.
4680 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4681 return mlir::failure();
4682 if (parser.parseComma())
4683 return mlir::failure();
4684
// The allocated type is recorded as the `in_type` attribute.
4685 mlir::Type intype;
4686 if (parser.parseType(intype))
4687 return mlir::failure();
4688 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4691 if (!parser.parseOptionalLParen()) {
4692 // parse the LEN params of the derived type. (<params> : <types>)
4694 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4695 return mlir::failure();
// Everything collected in `operands` so far is a type parameter.
4696 typeparamsSize = operands.size();
4697 hasOperands = true;
4698 }
4699 std::int32_t shapeSize = 0;
4700 if (!parser.parseOptionalComma()) {
4701 // parse size to scale by, vector of n dimensions of type index
4703 return mlir::failure();
// Shape operands are those appended after the type parameters; each of them
// is given the index type.
4704 shapeSize = operands.size() - typeparamsSize;
4705 auto idxTy = builder.getIndexType();
4706 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4707 typeVec.push_back(idxTy);
4708 hasOperands = true;
4709 }
4710 if (hasOperands &&
4711 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4712 result.operands))
4713 return mlir::failure();
4714
// The result type is always the 64-bit integer type.
// NOTE(review): the null check below presumably never fires, since the type
// comes straight from the builder - confirm before relying on the diagnostic.
4715 mlir::Type restype = builder.getIntegerType(64);
4716 if (!restype) {
4717 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4718 return mlir::failure();
4719 }
// Segment sizes: one device operand, then the type parameters, then the shape
// operands.
4720 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4721 result.addAttribute("operandSegmentSizes",
4722 builder.getDenseI32ArrayAttr(segmentSizes));
4723 if (parser.parseOptionalAttrDict(result.attributes) ||
4724 parser.addTypeToList(restype, result.types))
4725 return mlir::failure();
4726 return mlir::success();
4727}
4728
4729mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4731 return parseTargetAllocMemOp(parser, result);
4732}
4733
// Prints the custom assembly form: the device type, the allocated type, the
// optional parenthesized type parameters with their types, each shape operand
// preceded by ", ", and finally the remaining attributes.
// NOTE(review): one statement between the first two `p <<` lines is missing
// from this listing (presumably the one printing the device operand) -
// confirm against the complete file.
4734void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4735 p << " ";
4737 p << " : ";
4738 p << getDevice().getType();
4739 p << ", ";
4740 p << getInType();
4741 if (!getTypeparams().empty()) {
4742 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4743 }
4744 for (auto sh : getShape()) {
4745 p << ", ";
4746 p.printOperand(sh);
4747 }
// `in_type` and `operandSegmentSizes` are elided from the attribute
// dictionary.
4748 p.printOptionalAttrDict((*this)->getAttrs(),
4749 {"in_type", "operandSegmentSizes"});
4750}
4751
4752llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4753 mlir::Type outType = getType();
4754 if (!mlir::dyn_cast<IntegerType>(outType))
4755 return emitOpError("must be a integer type");
4756 return mlir::success();
4757}
4758
4759//===----------------------------------------------------------------------===//
4760// WorkdistributeOp
4761//===----------------------------------------------------------------------===//
4762
4763LogicalResult WorkdistributeOp::verify() {
4764 // Check that region exists and is not empty
4765 Region &region = getRegion();
4766 if (region.empty())
4767 return emitOpError("region cannot be empty");
4768 // Verify single entry point.
4769 Block &entryBlock = region.front();
4770 if (entryBlock.empty())
4771 return emitOpError("region must contain a structured block");
4772 // Verify single exit point.
4773 bool hasTerminator = false;
4774 for (Block &block : region) {
4775 if (isa<TerminatorOp>(block.back())) {
4776 if (hasTerminator) {
4777 return emitOpError("region must have exactly one terminator");
4778 }
4779 hasTerminator = true;
4780 }
4781 }
4782 if (!hasTerminator) {
4783 return emitOpError("region must be terminated with omp.terminator");
4784 }
4785 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4786 // No implicit barrier at end
4787 if (isa<BarrierOp>(op)) {
4788 return emitOpError(
4789 "explicit barriers are not allowed in workdistribute region");
4790 }
4791 // Check for invalid nested constructs
4792 if (isa<ParallelOp>(op)) {
4793 return emitOpError(
4794 "nested parallel constructs not allowed in workdistribute");
4795 }
4796 if (isa<TeamsOp>(op)) {
4797 return emitOpError(
4798 "nested teams constructs not allowed in workdistribute");
4799 }
4800 return WalkResult::advance();
4801 });
4802 if (walkResult.wasInterrupted())
4803 return failure();
4804
4805 Operation *parentOp = (*this)->getParentOp();
4806 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4807 return emitOpError("workdistribute must be nested under teams");
4808 return success();
4809}
4810
4811//===----------------------------------------------------------------------===//
4812// Declare simd [7.7]
4813//===----------------------------------------------------------------------===//
4814
4815LogicalResult DeclareSimdOp::verify() {
4816 // Must be nested inside a function-like op
4817 auto func =
4818 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4819 if (!func)
4820 return emitOpError() << "must be nested inside a function";
4821
4822 if (getInbranch() && getNotinbranch())
4823 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4824
4825 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
4826 /*isDeclareSimd=*/true)))
4827 return failure();
4828
4829 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4830}
4831
4832void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4833 const DeclareSimdOperands &clauses) {
4834 MLIRContext *ctx = odsBuilder.getContext();
4835 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4836 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4837 clauses.linearVars, clauses.linearStepVars,
4838 clauses.linearVarTypes, clauses.linearModifiers,
4839 clauses.notinbranch, clauses.simdlen,
4840 clauses.uniformVars);
4841}
4842
4843//===----------------------------------------------------------------------===//
4844// Parser and printer for Uniform Clause
4845//===----------------------------------------------------------------------===//
4846
4847/// uniform ::= `uniform` `(` uniform-list `)`
4848/// uniform-list := uniform-val (`,` uniform-val)*
4849/// uniform-val := ssa-id `:` type
4850static ParseResult
4853 SmallVectorImpl<Type> &uniformTypes) {
4854 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
4855 if (parser.parseOperand(uniformVars.emplace_back()) ||
4856 parser.parseColonType(uniformTypes.emplace_back()))
4857 return mlir::failure();
4858 return mlir::success();
4859 });
4860}
4861
4862/// Print Uniform Clauses
4864 ValueRange uniformVars, TypeRange uniformTypes) {
4865 for (unsigned i = 0; i < uniformVars.size(); ++i) {
4866 if (i != 0)
4867 p << ", ";
4868 p << uniformVars[i] << " : " << uniformTypes[i];
4869 }
4870}
4871
4872//===----------------------------------------------------------------------===//
4873// Parser and printer for Affinity Clause
4874//===----------------------------------------------------------------------===//
4875
4876static ParseResult parseAffinityClause(
4877 OpAsmParser &parser,
4880 SmallVectorImpl<Type> &iteratedTypes,
4881 SmallVectorImpl<Type> &affinityVarTypes) {
4882 if (failed(parseSplitIteratedList(
4883 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4884 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
4885 return failure();
4886 return success();
4887}
4888
4890 ValueRange iterated, ValueRange affinityVars,
4891 TypeRange iteratedTypes,
4892 TypeRange affinityVarTypes) {
4893 auto nop = [&](Value, Type) {};
4894 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
4895 affinityVarTypes,
4896 /*plain prefix*/ nop,
4897 /*iterated prefix*/ nop);
4898}
4899
4900//===----------------------------------------------------------------------===//
4901// Parser, printer, and verifier for Iterator modifier
4902//===----------------------------------------------------------------------===//
4903
4904static ParseResult
4909 SmallVectorImpl<Type> &lbTypes,
4910 SmallVectorImpl<Type> &ubTypes,
4911 SmallVectorImpl<Type> &stepTypes) {
4912
4913 llvm::SMLoc ivLoc = parser.getCurrentLocation();
4915
4916 // Parse induction variables: %i : i32, %j : i32
4917 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4918 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4919 if (parser.parseArgument(arg))
4920 return failure();
4921
4922 // Optional type, default to Index if not provided
4923 if (succeeded(parser.parseOptionalColon())) {
4924 if (parser.parseType(arg.type))
4925 return failure();
4926 } else {
4927 arg.type = parser.getBuilder().getIndexType();
4928 }
4929 return success();
4930 }))
4931 return failure();
4932
4933 // ) = (
4934 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
4935 return failure();
4936
4937 // Parse Ranges: (%lb to %ub step %st, ...)
4938 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4939 OpAsmParser::UnresolvedOperand lb, ub, st;
4940 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
4941 parser.parseOperand(ub) || parser.parseKeyword("step") ||
4942 parser.parseOperand(st))
4943 return failure();
4944
4945 lbs.push_back(lb);
4946 ubs.push_back(ub);
4947 steps.push_back(st);
4948 return success();
4949 }))
4950 return failure();
4951
4952 if (parser.parseRParen())
4953 return failure();
4954
4955 if (ivArgs.size() != lbs.size())
4956 return parser.emitError(ivLoc)
4957 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
4958 << " ranges";
4959
4960 for (auto &arg : ivArgs) {
4961 lbTypes.push_back(arg.type);
4962 ubTypes.push_back(arg.type);
4963 stepTypes.push_back(arg.type);
4964 }
4965
4966 return parser.parseRegion(region, ivArgs);
4967}
4968
4970 ValueRange lbs, ValueRange ubs,
4972 TypeRange) {
4973 Block &entry = region.front();
4974
4975 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
4976 if (i != 0)
4977 p << ", ";
4978 p.printRegionArgument(entry.getArgument(i));
4979 }
4980 p << ") = (";
4981
4982 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
4983 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
4984 if (i)
4985 p << ", ";
4986 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
4987 }
4988 p << ") ";
4989
4990 p.printRegion(region, /*printEntryBlockArgs=*/false,
4991 /*printBlockTerminators=*/true);
4992}
4993
4994LogicalResult IteratorOp::verify() {
4995 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
4996 if (!iteratedTy)
4997 return emitOpError() << "result must be omp.iterated<entry_ty>";
4998
4999 for (auto [lb, ub, step] : llvm::zip_equal(
5000 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5001 if (matchPattern(step, m_Zero()))
5002 return emitOpError() << "loop step must not be zero";
5003
5004 IntegerAttr lbAttr;
5005 IntegerAttr ubAttr;
5006 IntegerAttr stepAttr;
5007 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
5008 !matchPattern(ub, m_Constant(&ubAttr)) ||
5009 !matchPattern(step, m_Constant(&stepAttr)))
5010 continue;
5011
5012 const APInt &lbVal = lbAttr.getValue();
5013 const APInt &ubVal = ubAttr.getValue();
5014 const APInt &stepVal = stepAttr.getValue();
5015 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5016 return emitOpError() << "positive loop step requires lower bound to be "
5017 "less than or equal to upper bound";
5018 if (stepVal.isNegative() && lbVal.slt(ubVal))
5019 return emitOpError() << "negative loop step requires lower bound to be "
5020 "greater than or equal to upper bound";
5021 }
5022
5023 Block &b = getRegion().front();
5024 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
5025
5026 if (!yield)
5027 return emitOpError() << "region must be terminated by omp.yield";
5028
5029 if (yield.getNumOperands() != 1)
5030 return emitOpError()
5031 << "omp.yield in omp.iterator region must yield exactly one value";
5032
5033 mlir::Type yieldedTy = yield.getOperand(0).getType();
5034 mlir::Type elemTy = iteratedTy.getElementType();
5035
5036 if (yieldedTy != elemTy)
5037 return emitOpError() << "omp.iterated element type (" << elemTy
5038 << ") does not match omp.yield operand type ("
5039 << yieldedTy << ")";
5040
5041 return success();
5042}
5043
5044#define GET_ATTRDEF_CLASSES
5045#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5046
5047#define GET_OP_CLASSES
5048#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5049
5050#define GET_TYPEDEF_CLASSES
5051#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check whether an SPMD kernel can be promoted to a No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool hasTrait()
Returns true if the operation was registered with a particular trait.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.