MLIR 23.0.0git
OpenACCCG.cpp
Go to the documentation of this file.
1//===- OpenACCCG.cpp - OpenACC codegen ops, attributes, and types ---------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation for OpenACC codegen operations, attributes, and types.
10// These correspond to the definitions in OpenACCCG*.td tablegen files
11// and are kept in a separate file because they do not represent direct mappings
12// of OpenACC language constructs; they are intermediate representations used
13// when decomposing and lowering primary `acc` dialect operations.
14//
15//===----------------------------------------------------------------------===//
16
23#include "mlir/IR/Region.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
29
30using namespace mlir;
31using namespace acc;
32
33namespace {
34
35/// Generic helper for single-region OpenACC ops that execute their body once
36/// and then return to the parent operation with their results (if any).
37static void
41 if (point.isParent()) {
42 regions.push_back(RegionSuccessor(&region));
43 return;
44 }
45 regions.push_back(RegionSuccessor::parent());
46}
47
49 RegionSuccessor successor) {
50 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
51}
52
53/// Remove empty acc.kernel_environment operations. If the operation has wait
54/// operands, create a acc.wait operation to preserve synchronization.
55struct RemoveEmptyKernelEnvironment
56 : public OpRewritePattern<acc::KernelEnvironmentOp> {
57 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
60 PatternRewriter &rewriter) const override {
61 assert(op->getNumRegions() == 1 && "expected op to have one region");
62
63 Block &block = op.getRegion().front();
64 if (!block.empty())
65 return failure();
66
67 // Remove empty kernel environment.
68 // Preserve synchronization by creating acc.wait operation if needed.
69 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
70 rewriter.replaceOpWithNewOp<acc::WaitOp>(
71 op, op.getWaitOperands(), /*asyncOperand=*/Value(),
72 op.getWaitDevnum(), /*async=*/nullptr, /*ifCond=*/Value());
73 else
74 rewriter.eraseOp(op);
75
76 return success();
77 }
78};
79
80static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
81 PatternRewriter &rewriter,
82 size_t numInput) {
83 const size_t numLaunch = op.getLaunchArgs().size();
84 op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
85 rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
86 static_cast<int32_t>(numInput),
87 op.getStream() ? 1 : 0}));
88}
89
90struct ComputeRegionRemoveDuplicateArgs
91 : public OpRewritePattern<ComputeRegionOp> {
93
94 LogicalResult matchAndRewrite(ComputeRegionOp op,
95 PatternRewriter &rewriter) const override {
96 Block *body = op.getBody();
97 const size_t numLaunch = op.getLaunchArgs().size();
98 size_t numInput = op.getInputArgs().size();
99 assert(body->getNumArguments() == numLaunch + numInput &&
100 "region args mismatch");
101
102 bool mergedAny = false;
103 while (true) {
104 bool merged = false;
105 for (size_t j = 1; j < numInput && !merged; ++j) {
106 for (size_t i = 0; i < j; ++i) {
107 if (op->getOperand(static_cast<unsigned>(numLaunch + i)) !=
108 op->getOperand(static_cast<unsigned>(numLaunch + j)))
109 continue;
110 unsigned keepIdx = static_cast<unsigned>(numLaunch + i);
111 unsigned dropIdx = static_cast<unsigned>(numLaunch + j);
112 rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
113 body->getArgument(keepIdx));
114 body->eraseArgument(dropIdx);
115 op->eraseOperand(dropIdx);
116 --numInput;
117 merged = true;
118 mergedAny = true;
119 break;
120 }
121 }
122 if (!merged)
123 break;
124 }
125
126 if (!mergedAny)
127 return failure();
128 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
129 return success();
130 }
131};
132
133struct ComputeRegionRemoveUnusedArgs
134 : public OpRewritePattern<ComputeRegionOp> {
136
137 LogicalResult matchAndRewrite(ComputeRegionOp op,
138 PatternRewriter &rewriter) const override {
139 Block *body = op.getBody();
140 const size_t numLaunch = op.getLaunchArgs().size();
141 size_t numInput = op.getInputArgs().size();
142 assert(body->getNumArguments() == numLaunch + numInput &&
143 "region args mismatch");
144
145 bool changed = false;
146 for (size_t k = numLaunch; k < numLaunch + numInput;) {
147 if (!body->getArgument(static_cast<unsigned>(k)).use_empty()) {
148 ++k;
149 continue;
150 }
151 body->eraseArgument(static_cast<unsigned>(k));
152 op->eraseOperand(static_cast<unsigned>(k));
153 --numInput;
154 changed = true;
155 }
156
157 if (!changed)
158 return failure();
159 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
160 return success();
161 }
162};
163
164template <typename EffectTy>
165static void addOperandEffect(
167 &effects,
168 const MutableOperandRange &operand) {
169 for (unsigned i = 0, e = operand.size(); i < e; ++i)
170 effects.emplace_back(EffectTy::get(), &operand[i]);
171}
172
173template <typename EffectTy>
174static void addResultEffect(
176 &effects,
177 Value result) {
178 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
179}
180
181static int64_t gpuProcessorIndex(gpu::Processor p) {
182 switch (p) {
183 case gpu::Processor::Sequential:
184 return 0;
185 case gpu::Processor::ThreadX:
186 return 1;
187 case gpu::Processor::ThreadY:
188 return 2;
189 case gpu::Processor::ThreadZ:
190 return 3;
191 case gpu::Processor::BlockX:
192 return 4;
193 case gpu::Processor::BlockY:
194 return 5;
195 case gpu::Processor::BlockZ:
196 return 6;
197 }
198 llvm_unreachable("unhandled gpu::Processor");
199}
200
201static gpu::Processor indexToGpuProcessor(int64_t idx) {
202 switch (idx) {
203 case 0:
204 return gpu::Processor::Sequential;
205 case 1:
206 return gpu::Processor::ThreadX;
207 case 2:
208 return gpu::Processor::ThreadY;
209 case 3:
210 return gpu::Processor::ThreadZ;
211 case 4:
212 return gpu::Processor::BlockX;
213 case 5:
214 return gpu::Processor::BlockY;
215 case 6:
216 return gpu::Processor::BlockZ;
217 default:
218 return gpu::Processor::Sequential;
219 }
220}
221
222static GPUParallelDimAttr intToParDim(MLIRContext *context, int64_t dimInt) {
223 return GPUParallelDimAttr::get(
224 context, IntegerAttr::get(IndexType::get(context), dimInt));
225}
226
227static GPUParallelDimAttr processorParDim(MLIRContext *context,
228 gpu::Processor proc) {
229 return GPUParallelDimAttr::get(
230 context,
231 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
232}
233
234static ParseResult parseProcessorValue(AsmParser &parser,
235 GPUParallelDimAttr &dim) {
236 std::string keyword;
237 llvm::SMLoc loc = parser.getCurrentLocation();
238 if (failed(parser.parseKeywordOrString(&keyword)))
239 return failure();
240 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
241 if (!maybeProcessor)
242 return parser.emitError(loc)
243 << "expected one of ::mlir::gpu::Processor enum names";
244 dim = intToParDim(parser.getContext(), gpuProcessorIndex(*maybeProcessor));
245 return success();
246}
247
248static void printProcessorValue(AsmPrinter &printer,
249 const GPUParallelDimAttr &attr) {
250 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
251 printer << gpu::stringifyProcessor(processor);
252}
253
254} // namespace
255
256//===----------------------------------------------------------------------===//
257// KernelEnvironmentOp
258//===----------------------------------------------------------------------===//
259
260void KernelEnvironmentOp::getSuccessorRegions(
262 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
263 regions);
264}
265
266ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
267 return getSingleRegionSuccessorInputs(getOperation(), successor);
268}
269
270void KernelEnvironmentOp::getCanonicalizationPatterns(
271 RewritePatternSet &results, MLIRContext *context) {
272 results.add<RemoveEmptyKernelEnvironment>(context);
273}
274
275/// Extract async for `clauseDeviceType`. Returns true if a clause was found.
276template <typename ComputeConstructT>
277static bool
278extractAsyncClause(ComputeConstructT computeConstruct,
279 DeviceType clauseDeviceType, MLIRContext *context,
280 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly) {
281 if (computeConstruct.hasAsyncOnly(clauseDeviceType)) {
282 asyncOnly = UnitAttr::get(context);
283 return true;
284 }
285 if (Value asyncValue = computeConstruct.getAsyncValue(clauseDeviceType)) {
286 asyncOperand = asyncValue;
287 return true;
288 }
289 return false;
290}
291
292/// Extract wait for `clauseDeviceType`. Returns true if a clause was found.
293template <typename ComputeConstructT>
294static bool extractWaitClause(ComputeConstructT computeConstruct,
295 DeviceType clauseDeviceType, MLIRContext *context,
296 std::optional<Value> &waitDevnum,
297 SmallVectorImpl<Value> &waitOperands,
298 UnitAttr &waitOnly) {
299 if (computeConstruct.hasWaitOnly(clauseDeviceType)) {
300 waitOnly = UnitAttr::get(context);
301 return true;
302 }
303 Value devnum = computeConstruct.getWaitDevnum(clauseDeviceType);
304 auto waitValues = computeConstruct.getWaitValues(clauseDeviceType);
305 if (!devnum && waitValues.empty())
306 return false;
307 if (devnum)
308 waitDevnum = devnum;
309 waitOperands.append(waitValues.begin(), waitValues.end());
310 return true;
311}
312
313template <typename ComputeConstructT>
315 ComputeConstructT computeConstruct, DeviceType deviceType,
316 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly,
317 std::optional<Value> &waitDevnum, SmallVectorImpl<Value> &waitOperands,
318 UnitAttr &waitOnly) {
319 MLIRContext *context = computeConstruct->getContext();
320
321 // Prefer device_type-specific clauses, then default ones.
322 if (!extractAsyncClause(computeConstruct, deviceType, context, asyncOperand,
323 asyncOnly)) {
324 if (deviceType != DeviceType::None)
325 extractAsyncClause(computeConstruct, DeviceType::None, context,
326 asyncOperand, asyncOnly);
327 }
328
329 if (!extractWaitClause(computeConstruct, deviceType, context, waitDevnum,
330 waitOperands, waitOnly)) {
331 if (deviceType != DeviceType::None)
332 extractWaitClause(computeConstruct, DeviceType::None, context, waitDevnum,
333 waitOperands, waitOnly);
334 }
335}
336
337template <typename ComputeConstructT>
338KernelEnvironmentOp
339KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
340 DeviceType deviceType,
341 OpBuilder &builder) {
342 std::optional<Value> asyncOperand;
343 UnitAttr asyncOnly = nullptr;
344 std::optional<Value> waitDevnum;
345 SmallVector<Value> waitOperands;
346 UnitAttr waitOnly = nullptr;
347 populateKernelEnvironmentAsyncWait(computeConstruct, deviceType, asyncOperand,
348 asyncOnly, waitDevnum, waitOperands,
349 waitOnly);
350
351 auto kernelEnvironment = KernelEnvironmentOp::create(
352 builder, computeConstruct->getLoc(),
353 computeConstruct.getDataClauseOperands(), asyncOperand.value_or(Value()),
354 asyncOnly, waitDevnum.value_or(Value()), waitOperands, waitOnly);
355 Block &block = kernelEnvironment.getRegion().emplaceBlock();
356 builder.setInsertionPointToStart(&block);
357 return kernelEnvironment;
358}
359
360template KernelEnvironmentOp
361KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, DeviceType,
362 OpBuilder &);
363template KernelEnvironmentOp
364KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, DeviceType,
365 OpBuilder &);
366template KernelEnvironmentOp
367KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, DeviceType,
368 OpBuilder &);
369
370LogicalResult KernelEnvironmentOp::verify() {
371 if (getAsyncOnly() && getAsyncOperand())
372 return emitError("async-only cannot appear with async operand");
373 if (getWaitOnly() && (!getWaitOperands().empty() || getWaitDevnum()))
374 return emitError("wait-only cannot appear with wait operands or devnum");
375 return success();
376}
377
378//===----------------------------------------------------------------------===//
379// FirstprivateMapInitialOp
380//===----------------------------------------------------------------------===//
381
382LogicalResult FirstprivateMapInitialOp::verify() {
383 if (getDataClause() != acc::DataClause::acc_firstprivate)
384 return emitError("data clause associated with firstprivate operation must "
385 "match its intent");
386 if (!getVar())
387 return emitError("must have var operand");
388 if (!mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
389 !mlir::isa<mlir::acc::MappableType>(getVar().getType()))
390 return emitError("var must be mappable or pointer-like");
391 if (mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
392 getVarType() == getVar().getType())
393 return emitError("varType must capture the element type of var");
394 if (getModifiers() != acc::DataClauseModifier::none)
395 return emitError("no data clause modifiers are allowed");
396 return success();
397}
398
399void FirstprivateMapInitialOp::getEffects(
401 &effects) {
402 effects.emplace_back(MemoryEffects::Read::get(),
404 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
406}
407
408//===----------------------------------------------------------------------===//
409// ReductionInitOp
410//===----------------------------------------------------------------------===//
411
412void ReductionInitOp::getSuccessorRegions(
414 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
415 regions);
416}
417
418void ReductionInitOp::getRegionInvocationBounds(
419 ArrayRef<Attribute> operands,
420 SmallVectorImpl<InvocationBounds> &invocationBounds) {
421 invocationBounds.emplace_back(1, 1);
422}
423
424ValueRange ReductionInitOp::getSuccessorInputs(RegionSuccessor successor) {
425 return getSingleRegionSuccessorInputs(getOperation(), successor);
426}
427
428LogicalResult ReductionInitOp::verify() {
429 Block &block = getRegion().front();
430 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
431 if (yieldOp.getNumOperands() != 1)
432 return emitOpError(
433 "region must yield exactly one value (private storage)");
434 if (yieldOp.getOperand(0).getType() != getVar().getType())
435 return emitOpError("yielded value type must match var type");
436 }
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// ReductionCombineRegionOp
442//===----------------------------------------------------------------------===//
443
444void ReductionCombineRegionOp::getSuccessorRegions(
446 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
447 regions);
448}
449
450void ReductionCombineRegionOp::getRegionInvocationBounds(
451 ArrayRef<Attribute> operands,
452 SmallVectorImpl<InvocationBounds> &invocationBounds) {
453 invocationBounds.emplace_back(1, 1);
454}
455
457ReductionCombineRegionOp::getSuccessorInputs(RegionSuccessor successor) {
458 return getSingleRegionSuccessorInputs(getOperation(), successor);
459}
460
461LogicalResult ReductionCombineRegionOp::verify() {
462 Block &block = getRegion().front();
463 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
464 if (yieldOp.getNumOperands() != 0)
465 return emitOpError("region must be terminated by acc.yield with no "
466 "operands");
467 }
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// ReductionCombineOp
473//===----------------------------------------------------------------------===//
474
475void ReductionCombineOp::getEffects(
477 &effects) {
478 effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemrefMutable(),
480 effects.emplace_back(MemoryEffects::Read::get(), &getDestMemrefMutable(),
482 effects.emplace_back(MemoryEffects::Write::get(), &getDestMemrefMutable(),
484}
485
486//===----------------------------------------------------------------------===//
487// ComputeRegionOp
488//===----------------------------------------------------------------------===//
489
490static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
491 GPUParallelDimAttr parDim) {
492 for (auto launchArg : op.getLaunchArgs()) {
493 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
494 if (!parOp)
495 continue;
496 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
497 if (launchArgDim == parDim)
498 return parOp;
499 }
500 return nullptr;
501}
502
503std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
504 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
505 return parWidthOp.getResult();
506 return {};
507}
508
509std::optional<Value>
510ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
511 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
512 if (parWidthOp.getLaunchArg())
513 return parWidthOp.getLaunchArg();
514 return {};
515}
516
517std::optional<uint64_t>
518ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
519 auto knownParWidth = getKnownLaunchArg(parDim);
520 if (knownParWidth.has_value())
521 return getConstantIntValue(knownParWidth.value());
522 return {};
523}
524
525BlockArgument ComputeRegionOp::appendInputArg(Value value) {
526 getInputArgsMutable().append(value);
527 return getBody()->addArgument(value.getType(), getLoc());
528}
529
530std::optional<BlockArgument>
531ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
532 Region &region = getRegion();
533
534 auto useIsInRegion = [&](OpOperand &use) -> bool {
535 return region.isAncestor(use.getOwner()->getParentRegion());
536 };
537
538 if (!areValuesDefinedAbove(ValueRange(value), region) ||
539 !llvm::any_of(value.getUses(), useIsInRegion))
540 return std::nullopt;
541
542 BlockArgument arg = appendInputArg(value);
543 replaceAllUsesInRegionWith(value, arg, region);
544 return arg;
545}
546
547bool ComputeRegionOp::isEffectivelySerial() {
548 auto *ctx = getContext();
549
550 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
551 return true;
552
553 auto checkDim = [&](GPUParallelDimAttr dim) -> bool {
554 auto val = getKnownConstantLaunchArg(dim);
555 return val && *val == 1;
556 };
557
558 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
559 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
560 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
561 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
562 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
563 checkDim(GPUParallelDimAttr::blockZDim(ctx));
564}
565
566BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
567 for (auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
568 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
569 assert(parOp);
570 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
571 if (launchArgDim == parDim) {
572 assert(pos < getRegion().front().getNumArguments() &&
573 "launch arg position out of range");
574 return getRegion().front().getArgument(pos);
575 }
576 }
577 llvm_unreachable("attempting to get unspecified parDim");
578}
579
580SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
582 for (auto launchArg : getLaunchArgs()) {
583 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
584 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
585 int64_t dimInt = launchArgDim.getValue().getInt();
586 parDims.push_back(intToParDim(getContext(), dimInt));
587 }
588 return parDims;
589}
590
591Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
592 Block *body = getBody();
593 if (blockArg.getOwner() != body)
594 return Value();
595 unsigned argNumber = blockArg.getArgNumber();
596 unsigned numLaunchArgs = getLaunchArgs().size();
597 unsigned numInputArgs = getInputArgs().size();
598 if (argNumber >= numLaunchArgs + numInputArgs)
599 return Value();
600 if (argNumber < numLaunchArgs)
601 return getLaunchArgs()[argNumber];
602 return getInputArgs()[argNumber - numLaunchArgs];
603}
604
605std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
606 Block *body = getBody();
607 for (auto [idx, launchVal] : llvm::enumerate(getLaunchArgs())) {
608 if (launchVal == value)
609 return body->getArgument(idx);
610 }
611 unsigned numLaunch = getLaunchArgs().size();
612 for (auto [idx, inputVal] : llvm::enumerate(getInputArgs())) {
613 if (inputVal == value)
614 return body->getArgument(numLaunch + idx);
615 }
616 return std::nullopt;
617}
618
619void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
620 MLIRContext *context) {
621 results.add<ComputeRegionRemoveDuplicateArgs, ComputeRegionRemoveUnusedArgs>(
622 context);
623}
624
625BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
626 return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
627}
628
629LogicalResult ComputeRegionOp::verify() {
630 for (auto op : getLaunchArgs())
631 if (!op.getDefiningOp<acc::ParWidthOp>())
632 return emitOpError(
633 "launch arguments must be results of acc.par_width operations");
634
635 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
636 unsigned actualBlockArgs = getRegion().front().getNumArguments();
637 if (expectedBlockArgs != actualBlockArgs)
638 return emitOpError("expected ")
639 << expectedBlockArgs << " block arguments (launch + input), got "
640 << actualBlockArgs;
641
642 return success();
643}
644
645void ComputeRegionOp::print(OpAsmPrinter &p) {
646 ValueRange regionArgs = getBody()->getArguments();
647 ValueRange launchArgs = getLaunchArgs();
648 ValueRange inputArgs = getInputArgs();
649
650 assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
651 "region args mismatch");
652
653 if (getStream())
654 p << " stream(" << getStream() << " : " << getStream().getType() << ")";
655
656 size_t i = 0;
657 if (!launchArgs.empty()) {
658 p << " launch(";
659 for (size_t j = 0; j < launchArgs.size(); ++j, ++i) {
660 p << regionArgs[i] << " = " << launchArgs[j];
661 if (j < launchArgs.size() - 1)
662 p << ", ";
663 }
664 p << ")";
665 }
666 if (!inputArgs.empty()) {
667 p << " ins(";
668 for (size_t j = 0; j < inputArgs.size(); ++j, ++i) {
669 p << regionArgs[i] << " = " << inputArgs[j];
670 if (j < inputArgs.size() - 1)
671 p << ", ";
672 }
673 p << ") : (";
674 for (size_t j = 0; j < inputArgs.size(); ++j) {
675 p << inputArgs[j].getType();
676 if (j < inputArgs.size() - 1)
677 p << ", ";
678 }
679 p << ")";
680 }
681 p.printOptionalArrowTypeList(getResultTypes());
682 p << " ";
683 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
684 p.printOptionalAttrDict((*this)->getAttrs(),
685 /*elidedAttrs=*/getOperandSegmentSizeAttr());
686}
687
688ParseResult ComputeRegionOp::parse(OpAsmParser &parser,
690 auto &builder = parser.getBuilder();
691
693 OpAsmParser::UnresolvedOperand streamOperand;
694 Type streamType;
697 SmallVector<Type> types;
698
699 bool hasStream = false;
700 if (succeeded(parser.parseOptionalKeyword("stream"))) {
701 hasStream = true;
702 if (parser.parseLParen() || parser.parseOperand(streamOperand) ||
703 parser.parseColon() || parser.parseType(streamType) ||
704 parser.parseRParen())
705 return failure();
706 }
707
708 if (succeeded(parser.parseOptionalKeyword("launch"))) {
709 if (parser.parseAssignmentList(regionArgs, launchOperands))
710 return failure();
711 Type indexType = builder.getIndexType();
712 for (size_t i = 0; i < regionArgs.size(); ++i)
713 types.push_back(indexType);
714 }
715
716 if (succeeded(parser.parseOptionalKeyword("ins"))) {
717 if (parser.parseAssignmentList(regionArgs, inputOperands) ||
718 parser.parseColon() || parser.parseLParen() ||
719 parser.parseTypeList(types) || parser.parseRParen())
720 return failure();
721 }
722
723 if (parser.parseOptionalArrowTypeList(result.types))
724 return failure();
725
726 for (auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
727 iterArg.type = type;
728
729 Region *body = result.addRegion();
730 if (parser.parseRegion(*body, regionArgs))
731 return failure();
732 ComputeRegionOp::ensureTerminator(*body, parser.getBuilder(),
733 result.location);
734
735 const size_t numLaunchOperands = launchOperands.size();
736 const size_t numInputOperands = inputOperands.size();
737 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
738 "compute region args mismatch");
739
740 result.addAttribute(
741 ComputeRegionOp::getOperandSegmentSizeAttr(),
742 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunchOperands),
743 static_cast<int32_t>(numInputOperands),
744 hasStream ? 1 : 0}));
745
746 for (size_t i = 0; i < numLaunchOperands; ++i) {
747 if (parser.resolveOperand(launchOperands[i], types[i], result.operands))
748 return failure();
749 }
750
751 for (size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
752 if (parser.resolveOperand(inputOperands[i - numLaunchOperands], types[i],
753 result.operands))
754 return failure();
755 }
756
757 if (hasStream) {
758 if (parser.resolveOperand(streamOperand, streamType, result.operands))
759 return failure();
760 }
761
762 if (parser.parseOptionalAttrDict(result.attributes))
763 return failure();
764
765 return success();
766}
767
768//===----------------------------------------------------------------------===//
769// GPUParallelDimAttr
770//===----------------------------------------------------------------------===//
771
772GPUParallelDimAttr GPUParallelDimAttr::get(MLIRContext *context,
773 gpu::Processor proc) {
774 return processorParDim(context, proc);
775}
776
777GPUParallelDimAttr GPUParallelDimAttr::seqDim(MLIRContext *context) {
778 return processorParDim(context, gpu::Processor::Sequential);
779}
780
781GPUParallelDimAttr GPUParallelDimAttr::threadXDim(MLIRContext *context) {
782 return processorParDim(context, gpu::Processor::ThreadX);
783}
784
785GPUParallelDimAttr GPUParallelDimAttr::threadYDim(MLIRContext *context) {
786 return processorParDim(context, gpu::Processor::ThreadY);
787}
788
789GPUParallelDimAttr GPUParallelDimAttr::threadZDim(MLIRContext *context) {
790 return processorParDim(context, gpu::Processor::ThreadZ);
791}
792
793GPUParallelDimAttr GPUParallelDimAttr::blockXDim(MLIRContext *context) {
794 return processorParDim(context, gpu::Processor::BlockX);
795}
796
797GPUParallelDimAttr GPUParallelDimAttr::blockYDim(MLIRContext *context) {
798 return processorParDim(context, gpu::Processor::BlockY);
799}
800
801GPUParallelDimAttr GPUParallelDimAttr::blockZDim(MLIRContext *context) {
802 return processorParDim(context, gpu::Processor::BlockZ);
803}
804
805Attribute GPUParallelDimAttr::parse(AsmParser &parser, Type type) {
806 GPUParallelDimAttr dim;
807 if (parser.parseLess() || parseProcessorValue(parser, dim) ||
808 parser.parseGreater()) {
809 parser.emitError(parser.getCurrentLocation(),
810 "expected format `<` processor_name `>`");
811 return {};
812 }
813 return dim;
814}
815
816void GPUParallelDimAttr::print(AsmPrinter &printer) const {
817 printer << "<";
818 printProcessorValue(printer, *this);
819 printer << ">";
820}
821
822GPUParallelDimAttr GPUParallelDimAttr::threadDim(MLIRContext *context,
823 unsigned index) {
824 assert(index <= 2 && "thread dimension index must be 0, 1, or 2");
825 switch (index) {
826 case 0:
827 return threadXDim(context);
828 case 1:
829 return threadYDim(context);
830 case 2:
831 return threadZDim(context);
832 }
833 llvm_unreachable("validated thread dimension index");
834}
835
836GPUParallelDimAttr GPUParallelDimAttr::blockDim(MLIRContext *context,
837 unsigned index) {
838 assert(index <= 2 && "block dimension index must be 0, 1, or 2");
839 switch (index) {
840 case 0:
841 return blockXDim(context);
842 case 1:
843 return blockYDim(context);
844 case 2:
845 return blockZDim(context);
846 }
847 llvm_unreachable("validated block dimension index");
848}
849
850gpu::Processor GPUParallelDimAttr::getProcessor() const {
851 return indexToGpuProcessor(getValue().getInt());
852}
853
854int GPUParallelDimAttr::getOrder() const {
855 return gpuProcessorIndex(getProcessor());
856}
857
858GPUParallelDimAttr GPUParallelDimAttr::getOneHigher() const {
859 int order = getOrder();
860 if (order >= 6) // BlockZ is the highest
861 return *this;
862 return get(getContext(), indexToGpuProcessor(order + 1));
863}
864
865GPUParallelDimAttr GPUParallelDimAttr::getOneLower() const {
866 int order = getOrder();
867 if (order <= 0) // Sequential is the lowest
868 return *this;
869 return get(getContext(), indexToGpuProcessor(order - 1));
870}
871
872bool GPUParallelDimAttr::isSeq() const {
873 return getProcessor() == gpu::Processor::Sequential;
874}
875bool GPUParallelDimAttr::isThreadX() const {
876 return getProcessor() == gpu::Processor::ThreadX;
877}
878bool GPUParallelDimAttr::isThreadY() const {
879 return getProcessor() == gpu::Processor::ThreadY;
880}
881bool GPUParallelDimAttr::isThreadZ() const {
882 return getProcessor() == gpu::Processor::ThreadZ;
883}
884bool GPUParallelDimAttr::isBlockX() const {
885 return getProcessor() == gpu::Processor::BlockX;
886}
887bool GPUParallelDimAttr::isBlockY() const {
888 return getProcessor() == gpu::Processor::BlockY;
889}
890bool GPUParallelDimAttr::isBlockZ() const {
891 return getProcessor() == gpu::Processor::BlockZ;
892}
893bool GPUParallelDimAttr::isAnyThread() const {
894 return isThreadX() || isThreadY() || isThreadZ();
895}
896bool GPUParallelDimAttr::isAnyBlock() const {
897 return isBlockX() || isBlockY() || isBlockZ();
898}
899
900//===----------------------------------------------------------------------===//
901// GPUParallelDimsAttr
902//===----------------------------------------------------------------------===//
903
904GPUParallelDimsAttr GPUParallelDimsAttr::seq(MLIRContext *ctx) {
905 return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
906}
907
908bool GPUParallelDimsAttr::isSeq() const {
909 assert(!getArray().empty() && "no par_dims found");
910 if (getArray().size() == 1) {
911 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
912 assert(parDim && "expected GPUParallelDimAttr");
913 return parDim.isSeq();
914 }
915 return false;
916}
917
918bool GPUParallelDimsAttr::isParallel() const { return !isSeq(); }
919
920bool GPUParallelDimsAttr::isMultiDim() const { return getArray().size() > 1; }
921
922bool GPUParallelDimsAttr::hasAnyBlockLevel() const {
923 return llvm::any_of(
924 getArray(), [](const GPUParallelDimAttr &p) { return p.isAnyBlock(); });
925}
926
927bool GPUParallelDimsAttr::hasOnlyBlockLevel() const {
928 return !getArray().empty() &&
929 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
930 return p.isAnyBlock();
931 });
932}
933
934bool GPUParallelDimsAttr::hasOnlyThreadYLevel() const {
935 return !getArray().empty() &&
936 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
937 return p.isThreadY();
938 });
939}
940
941bool GPUParallelDimsAttr::hasOnlyThreadXLevel() const {
942 return !getArray().empty() &&
943 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
944 return p.isThreadX();
945 });
946}
947
948Attribute GPUParallelDimsAttr::parse(AsmParser &parser, Type type) {
949 auto delimiter = AsmParser::Delimiter::Square;
951 auto parseParDim = [&]() -> ParseResult {
952 GPUParallelDimAttr dim;
953 if (parseProcessorValue(parser, dim))
954 return failure();
955 parDims.push_back(dim);
956 return success();
957 };
958 if (parser.parseCommaSeparatedList(delimiter, parseParDim,
959 "list of OpenACC GPU parallel dimensions"))
960 return {};
961 return GPUParallelDimsAttr::get(parser.getContext(), parDims);
962}
963
964void GPUParallelDimsAttr::print(AsmPrinter &printer) const {
965 printer << "[";
966 llvm::interleaveComma(getArray(), printer,
967 [&printer](const GPUParallelDimAttr &p) {
968 printProcessorValue(printer, p);
969 });
970 printer << "]";
971}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1317
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1327
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the parent operation with their results (if any).
Definition OpenACC.cpp:511
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:522
b getContext())
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
static bool extractWaitClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
Extract wait for clauseDeviceType. Returns true if a clause was found.
static bool extractAsyncClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly)
Extract async for clauseDeviceType. Returns true if a clause was found.
static void populateKernelEnvironmentAsyncWait(ComputeConstructT computeConstruct, DeviceType deviceType, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
result_range getResults()
Definition Operation.h:440
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5234
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5203
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5307
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5289
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:5211
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.