MLIR 23.0.0git
OpenACCCG.cpp
Go to the documentation of this file.
1//===- OpenACCCG.cpp - OpenACC codegen ops, attributes, and types ---------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation for OpenACC codegen operations, attributes, and types.
10// These correspond to the definitions in OpenACCCG*.td tablegen files
11// and are kept in a separate file because they do not represent direct mappings
12// of OpenACC language constructs; they are intermediate representations used
13// when decomposing and lowering primary `acc` dialect operations.
14//
15//===----------------------------------------------------------------------===//
16
23#include "mlir/IR/Region.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
29
30using namespace mlir;
31using namespace acc;
32
33namespace {
34
35/// Generic helper for single-region OpenACC ops that execute their body once
36/// and then return to the parent operation with their results (if any).
37static void
41 if (point.isParent()) {
42 regions.push_back(RegionSuccessor(&region));
43 return;
44 }
45 regions.push_back(RegionSuccessor::parent());
46}
47
49 RegionSuccessor successor) {
50 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
51}
52
53/// Remove empty acc.kernel_environment operations. If the operation has wait
54/// operands, create a acc.wait operation to preserve synchronization.
55struct RemoveEmptyKernelEnvironment
56 : public OpRewritePattern<acc::KernelEnvironmentOp> {
57 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
60 PatternRewriter &rewriter) const override {
61 assert(op->getNumRegions() == 1 && "expected op to have one region");
62
63 Block &block = op.getRegion().front();
64 if (!block.empty())
65 return failure();
66
67 // Conservatively disable canonicalization of empty acc.kernel_environment
68 // operations if the wait operands in the kernel_environment cannot be fully
69 // represented by acc.wait operation.
70
71 // Disable canonicalization if device type is not the default
72 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
73 for (auto attr : deviceTypeAttr) {
74 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
75 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
76 return failure();
77 }
78 }
79 }
80
81 // Disable canonicalization if any wait segment has a devnum
82 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
83 for (auto attr : hasDevnumAttr) {
84 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
85 if (boolAttr.getValue())
86 return failure();
87 }
88 }
89 }
90
91 // Disable canonicalization if there are multiple wait segments
92 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
93 if (segmentsAttr.size() > 1)
94 return failure();
95 }
96
97 // Remove empty kernel environment.
98 // Preserve synchronization by creating acc.wait operation if needed.
99 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
100 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
101 /*asyncOperand=*/Value(),
102 /*waitDevnum=*/Value(),
103 /*async=*/nullptr,
104 /*ifCond=*/Value());
105 else
106 rewriter.eraseOp(op);
107
108 return success();
109 }
110};
111
112static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
113 PatternRewriter &rewriter,
114 size_t numInput) {
115 const size_t numLaunch = op.getLaunchArgs().size();
116 op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
117 rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
118 static_cast<int32_t>(numInput),
119 op.getStream() ? 1 : 0}));
120}
121
122struct ComputeRegionRemoveDuplicateArgs
123 : public OpRewritePattern<ComputeRegionOp> {
125
126 LogicalResult matchAndRewrite(ComputeRegionOp op,
127 PatternRewriter &rewriter) const override {
128 Block *body = op.getBody();
129 const size_t numLaunch = op.getLaunchArgs().size();
130 size_t numInput = op.getInputArgs().size();
131 assert(body->getNumArguments() == numLaunch + numInput &&
132 "region args mismatch");
133
134 bool mergedAny = false;
135 while (true) {
136 bool merged = false;
137 for (size_t j = 1; j < numInput && !merged; ++j) {
138 for (size_t i = 0; i < j; ++i) {
139 if (op->getOperand(static_cast<unsigned>(numLaunch + i)) !=
140 op->getOperand(static_cast<unsigned>(numLaunch + j)))
141 continue;
142 unsigned keepIdx = static_cast<unsigned>(numLaunch + i);
143 unsigned dropIdx = static_cast<unsigned>(numLaunch + j);
144 rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
145 body->getArgument(keepIdx));
146 body->eraseArgument(dropIdx);
147 op->eraseOperand(dropIdx);
148 --numInput;
149 merged = true;
150 mergedAny = true;
151 break;
152 }
153 }
154 if (!merged)
155 break;
156 }
157
158 if (!mergedAny)
159 return failure();
160 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
161 return success();
162 }
163};
164
165struct ComputeRegionRemoveUnusedArgs
166 : public OpRewritePattern<ComputeRegionOp> {
168
169 LogicalResult matchAndRewrite(ComputeRegionOp op,
170 PatternRewriter &rewriter) const override {
171 Block *body = op.getBody();
172 const size_t numLaunch = op.getLaunchArgs().size();
173 size_t numInput = op.getInputArgs().size();
174 assert(body->getNumArguments() == numLaunch + numInput &&
175 "region args mismatch");
176
177 bool changed = false;
178 for (size_t k = numLaunch; k < numLaunch + numInput;) {
179 if (!body->getArgument(static_cast<unsigned>(k)).use_empty()) {
180 ++k;
181 continue;
182 }
183 body->eraseArgument(static_cast<unsigned>(k));
184 op->eraseOperand(static_cast<unsigned>(k));
185 --numInput;
186 changed = true;
187 }
188
189 if (!changed)
190 return failure();
191 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
192 return success();
193 }
194};
195
196template <typename EffectTy>
197static void addOperandEffect(
199 &effects,
200 const MutableOperandRange &operand) {
201 for (unsigned i = 0, e = operand.size(); i < e; ++i)
202 effects.emplace_back(EffectTy::get(), &operand[i]);
203}
204
205template <typename EffectTy>
206static void addResultEffect(
208 &effects,
209 Value result) {
210 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
211}
212
/// Map a gpu::Processor to the integer encoding stored inside
/// GPUParallelDimAttr: 0 = sequential, 1-3 = thread x/y/z, 4-6 = block x/y/z.
/// Kept as an exhaustive switch so the compiler flags newly added
/// enumerators; indexToGpuProcessor() below is its inverse.
static int64_t gpuProcessorIndex(gpu::Processor p) {
  switch (p) {
  case gpu::Processor::Sequential:
    return 0;
  case gpu::Processor::ThreadX:
    return 1;
  case gpu::Processor::ThreadY:
    return 2;
  case gpu::Processor::ThreadZ:
    return 3;
  case gpu::Processor::BlockX:
    return 4;
  case gpu::Processor::BlockY:
    return 5;
  case gpu::Processor::BlockZ:
    return 6;
  }
  llvm_unreachable("unhandled gpu::Processor");
}
232
233static gpu::Processor indexToGpuProcessor(int64_t idx) {
234 switch (idx) {
235 case 0:
236 return gpu::Processor::Sequential;
237 case 1:
238 return gpu::Processor::ThreadX;
239 case 2:
240 return gpu::Processor::ThreadY;
241 case 3:
242 return gpu::Processor::ThreadZ;
243 case 4:
244 return gpu::Processor::BlockX;
245 case 5:
246 return gpu::Processor::BlockY;
247 case 6:
248 return gpu::Processor::BlockZ;
249 default:
250 return gpu::Processor::Sequential;
251 }
252}
253
254static GPUParallelDimAttr intToParDim(MLIRContext *context, int64_t dimInt) {
255 return GPUParallelDimAttr::get(
256 context, IntegerAttr::get(IndexType::get(context), dimInt));
257}
258
259static GPUParallelDimAttr processorParDim(MLIRContext *context,
260 gpu::Processor proc) {
261 return GPUParallelDimAttr::get(
262 context,
263 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
264}
265
266static ParseResult parseProcessorValue(AsmParser &parser,
267 GPUParallelDimAttr &dim) {
268 std::string keyword;
269 llvm::SMLoc loc = parser.getCurrentLocation();
270 if (failed(parser.parseKeywordOrString(&keyword)))
271 return failure();
272 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
273 if (!maybeProcessor)
274 return parser.emitError(loc)
275 << "expected one of ::mlir::gpu::Processor enum names";
276 dim = intToParDim(parser.getContext(), gpuProcessorIndex(*maybeProcessor));
277 return success();
278}
279
280static void printProcessorValue(AsmPrinter &printer,
281 const GPUParallelDimAttr &attr) {
282 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
283 printer << gpu::stringifyProcessor(processor);
284}
285
286} // namespace
287
288//===----------------------------------------------------------------------===//
289// KernelEnvironmentOp
290//===----------------------------------------------------------------------===//
291
292void KernelEnvironmentOp::getSuccessorRegions(
294 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
295 regions);
296}
297
/// Successor inputs: the op's results when control returns to the parent,
/// none when entering the region (see getSingleRegionSuccessorInputs).
ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
  return getSingleRegionSuccessorInputs(getOperation(), successor);
}
301
/// Register canonicalization patterns: fold away empty kernel environments
/// (see RemoveEmptyKernelEnvironment above).
void KernelEnvironmentOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveEmptyKernelEnvironment>(context);
}
306
307template <typename ComputeConstructT>
308KernelEnvironmentOp
309KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
310 OpBuilder &builder) {
311 auto kernelEnvironment = KernelEnvironmentOp::create(
312 builder, computeConstruct->getLoc(),
313 computeConstruct.getDataClauseOperands(),
314 computeConstruct.getAsyncOperands(),
315 computeConstruct.getAsyncOperandsDeviceTypeAttr(),
316 computeConstruct.getAsyncOnlyAttr(), computeConstruct.getWaitOperands(),
317 computeConstruct.getWaitOperandsSegmentsAttr(),
318 computeConstruct.getWaitOperandsDeviceTypeAttr(),
319 computeConstruct.getHasWaitDevnumAttr(),
320 computeConstruct.getWaitOnlyAttr());
321 Block &block = kernelEnvironment.getRegion().emplaceBlock();
322 builder.setInsertionPointToStart(&block);
323 return kernelEnvironment;
324}
325
// Explicit instantiations for the three OpenACC compute constructs that are
// decomposed into kernel environments.
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, OpBuilder &);
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, OpBuilder &);
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, OpBuilder &);
332
333//===----------------------------------------------------------------------===//
334// FirstprivateMapInitialOp
335//===----------------------------------------------------------------------===//
336
337LogicalResult FirstprivateMapInitialOp::verify() {
338 if (getDataClause() != acc::DataClause::acc_firstprivate)
339 return emitError("data clause associated with firstprivate operation must "
340 "match its intent");
341 if (!getVar())
342 return emitError("must have var operand");
343 if (!mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
344 !mlir::isa<mlir::acc::MappableType>(getVar().getType()))
345 return emitError("var must be mappable or pointer-like");
346 if (mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
347 getVarType() == getVar().getType())
348 return emitError("varType must capture the element type of var");
349 if (getModifiers() != acc::DataClauseModifier::none)
350 return emitError("no data clause modifiers are allowed");
351 return success();
352}
353
// Memory effects for acc.firstprivate.map_initial: registers a Read effect on
// the var operand via addOperandEffect.
// NOTE(review): this definition was truncated during extraction — the
// `effects` parameter's type and the tail of the first emplace_back call are
// missing. Recover both from the original source; do not compile as-is.
void FirstprivateMapInitialOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(),
  addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
}
362
363//===----------------------------------------------------------------------===//
364// ReductionInitOp
365//===----------------------------------------------------------------------===//
366
367void ReductionInitOp::getSuccessorRegions(
369 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
370 regions);
371}
372
373void ReductionInitOp::getRegionInvocationBounds(
374 ArrayRef<Attribute> operands,
375 SmallVectorImpl<InvocationBounds> &invocationBounds) {
376 invocationBounds.emplace_back(1, 1);
377}
378
/// Successor inputs: the op's results when returning to the parent, none
/// when entering the region.
ValueRange ReductionInitOp::getSuccessorInputs(RegionSuccessor successor) {
  return getSingleRegionSuccessorInputs(getOperation(), successor);
}
382
383LogicalResult ReductionInitOp::verify() {
384 Block &block = getRegion().front();
385 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
386 if (yieldOp.getNumOperands() != 1)
387 return emitOpError(
388 "region must yield exactly one value (private storage)");
389 if (yieldOp.getOperand(0).getType() != getVar().getType())
390 return emitOpError("yielded value type must match var type");
391 }
392 return success();
393}
394
395//===----------------------------------------------------------------------===//
396// ReductionCombineRegionOp
397//===----------------------------------------------------------------------===//
398
399void ReductionCombineRegionOp::getSuccessorRegions(
401 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
402 regions);
403}
404
405void ReductionCombineRegionOp::getRegionInvocationBounds(
406 ArrayRef<Attribute> operands,
407 SmallVectorImpl<InvocationBounds> &invocationBounds) {
408 invocationBounds.emplace_back(1, 1);
409}
410
412ReductionCombineRegionOp::getSuccessorInputs(RegionSuccessor successor) {
413 return getSingleRegionSuccessorInputs(getOperation(), successor);
414}
415
416LogicalResult ReductionCombineRegionOp::verify() {
417 Block &block = getRegion().front();
418 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
419 if (yieldOp.getNumOperands() != 0)
420 return emitOpError("region must be terminated by acc.yield with no "
421 "operands");
422 }
423 return success();
424}
425
426//===----------------------------------------------------------------------===//
427// ReductionCombineOp
428//===----------------------------------------------------------------------===//
429
// Memory effects for acc.reduction.combine: reads the src and dest memref
// operands, and writes the dest memref operand.
// NOTE(review): this definition was truncated during extraction — the
// `effects` parameter's type and the trailing argument of each emplace_back
// call (presumably a side-effect resource) are missing. Recover them from
// the original source; do not compile as-is.
void ReductionCombineOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemrefMutable(),
  effects.emplace_back(MemoryEffects::Read::get(), &getDestMemrefMutable(),
  effects.emplace_back(MemoryEffects::Write::get(), &getDestMemrefMutable(),
}
440
441//===----------------------------------------------------------------------===//
442// ComputeRegionOp
443//===----------------------------------------------------------------------===//
444
445static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
446 GPUParallelDimAttr parDim) {
447 for (auto launchArg : op.getLaunchArgs()) {
448 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
449 if (!parOp)
450 continue;
451 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
452 if (launchArgDim == parDim)
453 return parOp;
454 }
455 return nullptr;
456}
457
458std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
459 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
460 return parWidthOp.getResult();
461 return {};
462}
463
464std::optional<Value>
465ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
466 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
467 if (parWidthOp.getLaunchArg())
468 return parWidthOp.getLaunchArg();
469 return {};
470}
471
472std::optional<uint64_t>
473ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
474 auto knownParWidth = getKnownLaunchArg(parDim);
475 if (knownParWidth.has_value())
476 return getConstantIntValue(knownParWidth.value());
477 return {};
478}
479
480BlockArgument ComputeRegionOp::appendInputArg(Value value) {
481 getInputArgsMutable().append(value);
482 return getBody()->addArgument(value.getType(), getLoc());
483}
484
485std::optional<BlockArgument>
486ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
487 Region &region = getRegion();
488
489 auto useIsInRegion = [&](OpOperand &use) -> bool {
490 return region.isAncestor(use.getOwner()->getParentRegion());
491 };
492
493 if (!areValuesDefinedAbove(ValueRange(value), region) ||
494 !llvm::any_of(value.getUses(), useIsInRegion))
495 return std::nullopt;
496
497 BlockArgument arg = appendInputArg(value);
498 replaceAllUsesInRegionWith(value, arg, region);
499 return arg;
500}
501
502bool ComputeRegionOp::isEffectivelySerial() {
503 auto *ctx = getContext();
504
505 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
506 return true;
507
508 auto checkDim = [&](GPUParallelDimAttr dim) -> bool {
509 auto val = getKnownConstantLaunchArg(dim);
510 return val && *val == 1;
511 };
512
513 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
514 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
515 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
516 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
517 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
518 checkDim(GPUParallelDimAttr::blockZDim(ctx));
519}
520
521BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
522 for (auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
523 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
524 assert(parOp);
525 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
526 if (launchArgDim == parDim) {
527 assert(pos < getRegion().front().getNumArguments() &&
528 "launch arg position out of range");
529 return getRegion().front().getArgument(pos);
530 }
531 }
532 llvm_unreachable("attempting to get unspecified parDim");
533}
534
535SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
537 for (auto launchArg : getLaunchArgs()) {
538 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
539 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
540 int64_t dimInt = launchArgDim.getValue().getInt();
541 parDims.push_back(intToParDim(getContext(), dimInt));
542 }
543 return parDims;
544}
545
546Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
547 Block *body = getBody();
548 if (blockArg.getOwner() != body)
549 return Value();
550 unsigned argNumber = blockArg.getArgNumber();
551 unsigned numLaunchArgs = getLaunchArgs().size();
552 unsigned numInputArgs = getInputArgs().size();
553 if (argNumber >= numLaunchArgs + numInputArgs)
554 return Value();
555 if (argNumber < numLaunchArgs)
556 return getLaunchArgs()[argNumber];
557 return getInputArgs()[argNumber - numLaunchArgs];
558}
559
560std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
561 Block *body = getBody();
562 for (auto [idx, launchVal] : llvm::enumerate(getLaunchArgs())) {
563 if (launchVal == value)
564 return body->getArgument(idx);
565 }
566 unsigned numLaunch = getLaunchArgs().size();
567 for (auto [idx, inputVal] : llvm::enumerate(getInputArgs())) {
568 if (inputVal == value)
569 return body->getArgument(numLaunch + idx);
570 }
571 return std::nullopt;
572}
573
/// Register canonicalization patterns that clean up the `ins` operand list
/// (duplicate and unused input arguments).
void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<ComputeRegionRemoveDuplicateArgs, ComputeRegionRemoveUnusedArgs>(
      context);
}
579
/// Convenience overload: resolve the width block argument for a raw
/// gpu::Processor level.
BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
  return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
}
583
584LogicalResult ComputeRegionOp::verify() {
585 for (auto op : getLaunchArgs())
586 if (!op.getDefiningOp<acc::ParWidthOp>())
587 return emitOpError(
588 "launch arguments must be results of acc.par_width operations");
589
590 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
591 unsigned actualBlockArgs = getRegion().front().getNumArguments();
592 if (expectedBlockArgs != actualBlockArgs)
593 return emitOpError("expected ")
594 << expectedBlockArgs << " block arguments (launch + input), got "
595 << actualBlockArgs;
596
597 return success();
598}
599
600void ComputeRegionOp::print(OpAsmPrinter &p) {
601 ValueRange regionArgs = getBody()->getArguments();
602 ValueRange launchArgs = getLaunchArgs();
603 ValueRange inputArgs = getInputArgs();
604
605 assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
606 "region args mismatch");
607
608 if (getStream())
609 p << " stream(" << getStream() << " : " << getStream().getType() << ")";
610
611 size_t i = 0;
612 if (!launchArgs.empty()) {
613 p << " launch(";
614 for (size_t j = 0; j < launchArgs.size(); ++j, ++i) {
615 p << regionArgs[i] << " = " << launchArgs[j];
616 if (j < launchArgs.size() - 1)
617 p << ", ";
618 }
619 p << ")";
620 }
621 if (!inputArgs.empty()) {
622 p << " ins(";
623 for (size_t j = 0; j < inputArgs.size(); ++j, ++i) {
624 p << regionArgs[i] << " = " << inputArgs[j];
625 if (j < inputArgs.size() - 1)
626 p << ", ";
627 }
628 p << ") : (";
629 for (size_t j = 0; j < inputArgs.size(); ++j) {
630 p << inputArgs[j].getType();
631 if (j < inputArgs.size() - 1)
632 p << ", ";
633 }
634 p << ")";
635 }
636 p.printOptionalArrowTypeList(getResultTypes());
637 p << " ";
638 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
639 p.printOptionalAttrDict((*this)->getAttrs(),
640 /*elidedAttrs=*/getOperandSegmentSizeAttr());
641}
642
643ParseResult ComputeRegionOp::parse(OpAsmParser &parser,
645 auto &builder = parser.getBuilder();
646
648 OpAsmParser::UnresolvedOperand streamOperand;
649 Type streamType;
652 SmallVector<Type> types;
653
654 bool hasStream = false;
655 if (succeeded(parser.parseOptionalKeyword("stream"))) {
656 hasStream = true;
657 if (parser.parseLParen() || parser.parseOperand(streamOperand) ||
658 parser.parseColon() || parser.parseType(streamType) ||
659 parser.parseRParen())
660 return failure();
661 }
662
663 if (succeeded(parser.parseOptionalKeyword("launch"))) {
664 if (parser.parseAssignmentList(regionArgs, launchOperands))
665 return failure();
666 Type indexType = builder.getIndexType();
667 for (size_t i = 0; i < regionArgs.size(); ++i)
668 types.push_back(indexType);
669 }
670
671 if (succeeded(parser.parseOptionalKeyword("ins"))) {
672 if (parser.parseAssignmentList(regionArgs, inputOperands) ||
673 parser.parseColon() || parser.parseLParen() ||
674 parser.parseTypeList(types) || parser.parseRParen())
675 return failure();
676 }
677
678 if (parser.parseOptionalArrowTypeList(result.types))
679 return failure();
680
681 for (auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
682 iterArg.type = type;
683
684 Region *body = result.addRegion();
685 if (parser.parseRegion(*body, regionArgs))
686 return failure();
687 ComputeRegionOp::ensureTerminator(*body, parser.getBuilder(),
688 result.location);
689
690 const size_t numLaunchOperands = launchOperands.size();
691 const size_t numInputOperands = inputOperands.size();
692 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
693 "compute region args mismatch");
694
695 result.addAttribute(
696 ComputeRegionOp::getOperandSegmentSizeAttr(),
697 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunchOperands),
698 static_cast<int32_t>(numInputOperands),
699 hasStream ? 1 : 0}));
700
701 for (size_t i = 0; i < numLaunchOperands; ++i) {
702 if (parser.resolveOperand(launchOperands[i], types[i], result.operands))
703 return failure();
704 }
705
706 for (size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
707 if (parser.resolveOperand(inputOperands[i - numLaunchOperands], types[i],
708 result.operands))
709 return failure();
710 }
711
712 if (hasStream) {
713 if (parser.resolveOperand(streamOperand, streamType, result.operands))
714 return failure();
715 }
716
717 if (parser.parseOptionalAttrDict(result.attributes))
718 return failure();
719
720 return success();
721}
722
723//===----------------------------------------------------------------------===//
724// GPUParallelDimAttr
725//===----------------------------------------------------------------------===//
726
/// Build the attribute for an arbitrary gpu::Processor level.
GPUParallelDimAttr GPUParallelDimAttr::get(MLIRContext *context,
                                           gpu::Processor proc) {
  return processorParDim(context, proc);
}

/// Named factories, one per processor level, for callers that know the
/// dimension statically.
GPUParallelDimAttr GPUParallelDimAttr::seqDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::Sequential);
}

GPUParallelDimAttr GPUParallelDimAttr::threadXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadX);
}

GPUParallelDimAttr GPUParallelDimAttr::threadYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadY);
}

GPUParallelDimAttr GPUParallelDimAttr::threadZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadZ);
}

GPUParallelDimAttr GPUParallelDimAttr::blockXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockX);
}

GPUParallelDimAttr GPUParallelDimAttr::blockYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockY);
}

GPUParallelDimAttr GPUParallelDimAttr::blockZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockZ);
}
759
760Attribute GPUParallelDimAttr::parse(AsmParser &parser, Type type) {
761 GPUParallelDimAttr dim;
762 if (parser.parseLess() || parseProcessorValue(parser, dim) ||
763 parser.parseGreater()) {
764 parser.emitError(parser.getCurrentLocation(),
765 "expected format `<` processor_name `>`");
766 return {};
767 }
768 return dim;
769}
770
/// Print the attribute as `<` processor-name `>` (mirror of parse()).
void GPUParallelDimAttr::print(AsmPrinter &printer) const {
  printer << "<";
  printProcessorValue(printer, *this);
  printer << ">";
}
776
777GPUParallelDimAttr GPUParallelDimAttr::threadDim(MLIRContext *context,
778 unsigned index) {
779 assert(index <= 2 && "thread dimension index must be 0, 1, or 2");
780 switch (index) {
781 case 0:
782 return threadXDim(context);
783 case 1:
784 return threadYDim(context);
785 case 2:
786 return threadZDim(context);
787 }
788 llvm_unreachable("validated thread dimension index");
789}
790
791GPUParallelDimAttr GPUParallelDimAttr::blockDim(MLIRContext *context,
792 unsigned index) {
793 assert(index <= 2 && "block dimension index must be 0, 1, or 2");
794 switch (index) {
795 case 0:
796 return blockXDim(context);
797 case 1:
798 return blockYDim(context);
799 case 2:
800 return blockZDim(context);
801 }
802 llvm_unreachable("validated block dimension index");
803}
804
/// Decode the stored integer back into the gpu::Processor it encodes.
gpu::Processor GPUParallelDimAttr::getProcessor() const {
  return indexToGpuProcessor(getValue().getInt());
}

/// Ordering key: 0 = seq, 1-3 = thread x/y/z, 4-6 = block x/y/z.
int GPUParallelDimAttr::getOrder() const {
  return gpuProcessorIndex(getProcessor());
}
812
813GPUParallelDimAttr GPUParallelDimAttr::getOneHigher() const {
814 int order = getOrder();
815 if (order >= 6) // BlockZ is the highest
816 return *this;
817 return get(getContext(), indexToGpuProcessor(order + 1));
818}
819
820GPUParallelDimAttr GPUParallelDimAttr::getOneLower() const {
821 int order = getOrder();
822 if (order <= 0) // Sequential is the lowest
823 return *this;
824 return get(getContext(), indexToGpuProcessor(order - 1));
825}
826
// Processor-level predicates: compare the decoded processor against each
// gpu::Processor enumerator.
bool GPUParallelDimAttr::isSeq() const {
  return getProcessor() == gpu::Processor::Sequential;
}
bool GPUParallelDimAttr::isThreadX() const {
  return getProcessor() == gpu::Processor::ThreadX;
}
bool GPUParallelDimAttr::isThreadY() const {
  return getProcessor() == gpu::Processor::ThreadY;
}
bool GPUParallelDimAttr::isThreadZ() const {
  return getProcessor() == gpu::Processor::ThreadZ;
}
bool GPUParallelDimAttr::isBlockX() const {
  return getProcessor() == gpu::Processor::BlockX;
}
bool GPUParallelDimAttr::isBlockY() const {
  return getProcessor() == gpu::Processor::BlockY;
}
bool GPUParallelDimAttr::isBlockZ() const {
  return getProcessor() == gpu::Processor::BlockZ;
}
// Aggregate predicates over the thread and block groups.
bool GPUParallelDimAttr::isAnyThread() const {
  return isThreadX() || isThreadY() || isThreadZ();
}
bool GPUParallelDimAttr::isAnyBlock() const {
  return isBlockX() || isBlockY() || isBlockZ();
}
854
855//===----------------------------------------------------------------------===//
856// GPUParallelDimsAttr
857//===----------------------------------------------------------------------===//
858
/// Convenience factory: the singleton dimension list [seq].
GPUParallelDimsAttr GPUParallelDimsAttr::seq(MLIRContext *ctx) {
  return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
}
862
863bool GPUParallelDimsAttr::isSeq() const {
864 assert(!getArray().empty() && "no par_dims found");
865 if (getArray().size() == 1) {
866 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
867 assert(parDim && "expected GPUParallelDimAttr");
868 return parDim.isSeq();
869 }
870 return false;
871}
872
/// Parallel is defined as "not sequential".
bool GPUParallelDimsAttr::isParallel() const { return !isSeq(); }

/// True when mapped onto more than one hardware dimension.
bool GPUParallelDimsAttr::isMultiDim() const { return getArray().size() > 1; }
876
877bool GPUParallelDimsAttr::hasAnyBlockLevel() const {
878 return llvm::any_of(
879 getArray(), [](const GPUParallelDimAttr &p) { return p.isAnyBlock(); });
880}
881
882bool GPUParallelDimsAttr::hasOnlyBlockLevel() const {
883 return !getArray().empty() &&
884 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
885 return p.isAnyBlock();
886 });
887}
888
889bool GPUParallelDimsAttr::hasOnlyThreadYLevel() const {
890 return !getArray().empty() &&
891 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
892 return p.isThreadY();
893 });
894}
895
896bool GPUParallelDimsAttr::hasOnlyThreadXLevel() const {
897 return !getArray().empty() &&
898 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
899 return p.isThreadX();
900 });
901}
902
903Attribute GPUParallelDimsAttr::parse(AsmParser &parser, Type type) {
904 auto delimiter = AsmParser::Delimiter::Square;
906 auto parseParDim = [&]() -> ParseResult {
907 GPUParallelDimAttr dim;
908 if (parseProcessorValue(parser, dim))
909 return failure();
910 parDims.push_back(dim);
911 return success();
912 };
913 if (parser.parseCommaSeparatedList(delimiter, parseParDim,
914 "list of OpenACC GPU parallel dimensions"))
915 return {};
916 return GPUParallelDimsAttr::get(parser.getContext(), parDims);
917}
918
919void GPUParallelDimsAttr::print(AsmPrinter &printer) const {
920 printer << "[";
921 llvm::interleaveComma(getArray(), printer,
922 [&printer](const GPUParallelDimAttr &p) {
923 printProcessorValue(printer, p);
924 });
925 printer << "]";
926}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1295
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1305
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:493
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:504
b getContext())
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5174
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5143
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5247
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:5151
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.