MLIR 23.0.0git
OpenACCCG.cpp
Go to the documentation of this file.
1//===- OpenACCCG.cpp - OpenACC codegen ops, attributes, and types ---------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation for OpenACC codegen operations, attributes, and types.
10// These correspond to the definitions in OpenACCCG*.td tablegen files
11// and are kept in a separate file because they do not represent direct mappings
12// of OpenACC language constructs; they are intermediate representations used
13// when decomposing and lowering primary `acc` dialect operations.
14//
15//===----------------------------------------------------------------------===//
16
22#include "mlir/IR/Region.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVector.h"
27
28using namespace mlir;
29using namespace acc;
30
31namespace {
32
33/// Generic helper for single-region OpenACC ops that execute their body once
34/// and then return to the parent operation with their results (if any).
35static void
39 if (point.isParent()) {
40 regions.push_back(RegionSuccessor(&region));
41 return;
42 }
43 regions.push_back(RegionSuccessor::parent());
44}
45
47 RegionSuccessor successor) {
48 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
49}
50
51/// Remove empty acc.kernel_environment operations. If the operation has wait
52/// operands, create a acc.wait operation to preserve synchronization.
struct RemoveEmptyKernelEnvironment
    : public OpRewritePattern<acc::KernelEnvironmentOp> {
  using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;

  /// Erase an acc.kernel_environment whose body block is empty. When the op
  /// carries wait information that a plain acc.wait can faithfully express,
  /// the environment is replaced by an acc.wait so synchronization semantics
  /// are preserved; otherwise the pattern conservatively bails out.
  LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
                                PatternRewriter &rewriter) const override {
    assert(op->getNumRegions() == 1 && "expected op to have one region");

    // Only completely empty bodies are candidates for removal.
    Block &block = op.getRegion().front();
    if (!block.empty())
      return failure();

    // Conservatively disable canonicalization of empty acc.kernel_environment
    // operations if the wait operands in the kernel_environment cannot be fully
    // represented by acc.wait operation.

    // Disable canonicalization if device type is not the default
    // (acc.wait has no device_type segmentation, so any non-default device
    // type would be lost in the rewrite).
    if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
      for (auto attr : deviceTypeAttr) {
        if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
          if (dtAttr.getValue() != mlir::acc::DeviceType::None)
            return failure();
        }
      }
    }

    // Disable canonicalization if any wait segment has a devnum
    // (the devnum operand position cannot be recovered here).
    if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
      for (auto attr : hasDevnumAttr) {
        if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
          if (boolAttr.getValue())
            return failure();
        }
      }
    }

    // Disable canonicalization if there are multiple wait segments
    // (a single acc.wait models only one segment).
    if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
      if (segmentsAttr.size() > 1)
        return failure();
    }

    // Remove empty kernel environment.
    // Preserve synchronization by creating acc.wait operation if needed.
    if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
      rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
                                               /*asyncOperand=*/Value(),
                                               /*waitDevnum=*/Value(),
                                               /*async=*/nullptr,
                                               /*ifCond=*/Value());
    else
      rewriter.eraseOp(op);

    return success();
  }
};
109
110template <typename EffectTy>
111static void addOperandEffect(
113 &effects,
114 const MutableOperandRange &operand) {
115 for (unsigned i = 0, e = operand.size(); i < e; ++i)
116 effects.emplace_back(EffectTy::get(), &operand[i]);
117}
118
119template <typename EffectTy>
120static void addResultEffect(
122 &effects,
123 Value result) {
124 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
125}
126
127static int64_t gpuProcessorIndex(gpu::Processor p) {
128 switch (p) {
129 case gpu::Processor::Sequential:
130 return 0;
131 case gpu::Processor::ThreadX:
132 return 1;
133 case gpu::Processor::ThreadY:
134 return 2;
135 case gpu::Processor::ThreadZ:
136 return 3;
137 case gpu::Processor::BlockX:
138 return 4;
139 case gpu::Processor::BlockY:
140 return 5;
141 case gpu::Processor::BlockZ:
142 return 6;
143 }
144 llvm_unreachable("unhandled gpu::Processor");
145}
146
147static gpu::Processor indexToGpuProcessor(int64_t idx) {
148 switch (idx) {
149 case 0:
150 return gpu::Processor::Sequential;
151 case 1:
152 return gpu::Processor::ThreadX;
153 case 2:
154 return gpu::Processor::ThreadY;
155 case 3:
156 return gpu::Processor::ThreadZ;
157 case 4:
158 return gpu::Processor::BlockX;
159 case 5:
160 return gpu::Processor::BlockY;
161 case 6:
162 return gpu::Processor::BlockZ;
163 default:
164 return gpu::Processor::Sequential;
165 }
166}
167
168static GPUParallelDimAttr intToParDim(MLIRContext *context, int64_t dimInt) {
169 return GPUParallelDimAttr::get(
170 context, IntegerAttr::get(IndexType::get(context), dimInt));
171}
172
173static GPUParallelDimAttr processorParDim(MLIRContext *context,
174 gpu::Processor proc) {
175 return GPUParallelDimAttr::get(
176 context,
177 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
178}
179
180static ParseResult parseProcessorValue(AsmParser &parser,
181 GPUParallelDimAttr &dim) {
182 std::string keyword;
183 llvm::SMLoc loc = parser.getCurrentLocation();
184 if (failed(parser.parseKeywordOrString(&keyword)))
185 return failure();
186 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
187 if (!maybeProcessor)
188 return parser.emitError(loc)
189 << "expected one of ::mlir::gpu::Processor enum names";
190 dim = intToParDim(parser.getContext(), gpuProcessorIndex(*maybeProcessor));
191 return success();
192}
193
194static void printProcessorValue(AsmPrinter &printer,
195 const GPUParallelDimAttr &attr) {
196 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
197 printer << gpu::stringifyProcessor(processor);
198}
199
200} // namespace
201
202//===----------------------------------------------------------------------===//
203// KernelEnvironmentOp
204//===----------------------------------------------------------------------===//
205
206void KernelEnvironmentOp::getSuccessorRegions(
208 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
209 regions);
210}
211
212ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
213 return getSingleRegionSuccessorInputs(getOperation(), successor);
214}
215
/// Registers the canonicalization pattern that folds away empty
/// acc.kernel_environment ops (see RemoveEmptyKernelEnvironment above).
void KernelEnvironmentOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveEmptyKernelEnvironment>(context);
}
220
221template <typename ComputeConstructT>
222KernelEnvironmentOp
223KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
224 OpBuilder &builder) {
225 auto kernelEnvironment = KernelEnvironmentOp::create(
226 builder, computeConstruct->getLoc(),
227 computeConstruct.getDataClauseOperands(),
228 computeConstruct.getAsyncOperands(),
229 computeConstruct.getAsyncOperandsDeviceTypeAttr(),
230 computeConstruct.getAsyncOnlyAttr(), computeConstruct.getWaitOperands(),
231 computeConstruct.getWaitOperandsSegmentsAttr(),
232 computeConstruct.getWaitOperandsDeviceTypeAttr(),
233 computeConstruct.getHasWaitDevnumAttr(),
234 computeConstruct.getWaitOnlyAttr());
235 Block &block = kernelEnvironment.getRegion().emplaceBlock();
236 builder.setInsertionPointToStart(&block);
237 return kernelEnvironment;
238}
239
240template KernelEnvironmentOp
241KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, OpBuilder &);
242template KernelEnvironmentOp
243KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, OpBuilder &);
244template KernelEnvironmentOp
245KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, OpBuilder &);
246
247//===----------------------------------------------------------------------===//
248// FirstprivateMapInitialOp
249//===----------------------------------------------------------------------===//
250
251LogicalResult FirstprivateMapInitialOp::verify() {
252 if (getDataClause() != acc::DataClause::acc_firstprivate)
253 return emitError("data clause associated with firstprivate operation must "
254 "match its intent");
255 if (!getVar())
256 return emitError("must have var operand");
257 if (!mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
258 !mlir::isa<mlir::acc::MappableType>(getVar().getType()))
259 return emitError("var must be mappable or pointer-like");
260 if (mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
261 getVarType() == getVar().getType())
262 return emitError("varType must capture the element type of var");
263 if (getModifiers() != acc::DataClauseModifier::none)
264 return emitError("no data clause modifiers are allowed");
265 return success();
266}
267
// Declares the memory effects of acc.firstprivate_map_initial: at minimum a
// read of the var operand (via addOperandEffect below).
// NOTE(review): this definition was truncated by documentation extraction —
// the `effects` parameter type (SmallVectorImpl<SideEffects::EffectInstance<
// MemoryEffects::Effect>> per the file's index), the trailing arguments of
// the first emplace_back, and at least one further line are missing. Restore
// from the original source before building.
void FirstprivateMapInitialOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(),
  addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
}
276
277//===----------------------------------------------------------------------===//
278// ReductionInitOp
279//===----------------------------------------------------------------------===//
280
281void ReductionInitOp::getSuccessorRegions(
283 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
284 regions);
285}
286
287void ReductionInitOp::getRegionInvocationBounds(
288 ArrayRef<Attribute> operands,
289 SmallVectorImpl<InvocationBounds> &invocationBounds) {
290 invocationBounds.emplace_back(1, 1);
291}
292
293ValueRange ReductionInitOp::getSuccessorInputs(RegionSuccessor successor) {
294 return getSingleRegionSuccessorInputs(getOperation(), successor);
295}
296
297LogicalResult ReductionInitOp::verify() {
298 Block &block = getRegion().front();
299 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
300 if (yieldOp.getNumOperands() != 1)
301 return emitOpError(
302 "region must yield exactly one value (private storage)");
303 if (yieldOp.getOperand(0).getType() != getVar().getType())
304 return emitOpError("yielded value type must match var type");
305 }
306 return success();
307}
308
309//===----------------------------------------------------------------------===//
310// ReductionCombineRegionOp
311//===----------------------------------------------------------------------===//
312
313void ReductionCombineRegionOp::getSuccessorRegions(
315 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
316 regions);
317}
318
319void ReductionCombineRegionOp::getRegionInvocationBounds(
320 ArrayRef<Attribute> operands,
321 SmallVectorImpl<InvocationBounds> &invocationBounds) {
322 invocationBounds.emplace_back(1, 1);
323}
324
326ReductionCombineRegionOp::getSuccessorInputs(RegionSuccessor successor) {
327 return getSingleRegionSuccessorInputs(getOperation(), successor);
328}
329
330LogicalResult ReductionCombineRegionOp::verify() {
331 Block &block = getRegion().front();
332 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
333 if (yieldOp.getNumOperands() != 0)
334 return emitOpError("region must be terminated by acc.yield with no "
335 "operands");
336 }
337 return success();
338}
339
340//===----------------------------------------------------------------------===//
341// ReductionCombineOp
342//===----------------------------------------------------------------------===//
343
// Declares the memory effects of acc.reduction_combine: reads of src and
// dest, and a write of dest.
// NOTE(review): this definition was truncated by documentation extraction —
// the `effects` parameter type and the trailing argument of each
// emplace_back (presumably a side-effect resource such as
// SideEffects::DefaultResource::get(); confirm) are missing. Restore from
// the original source before building.
void ReductionCombineOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemrefMutable(),
  effects.emplace_back(MemoryEffects::Read::get(), &getDestMemrefMutable(),
  effects.emplace_back(MemoryEffects::Write::get(), &getDestMemrefMutable(),
}
354
355//===----------------------------------------------------------------------===//
356// ComputeRegionOp
357//===----------------------------------------------------------------------===//
358
359static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
360 GPUParallelDimAttr parDim) {
361 for (auto launchArg : op.getLaunchArgs()) {
362 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
363 if (!parOp)
364 continue;
365 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
366 if (launchArgDim == parDim)
367 return parOp;
368 }
369 return nullptr;
370}
371
372std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
373 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
374 return parWidthOp.getResult();
375 return {};
376}
377
378std::optional<Value>
379ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
380 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
381 if (parWidthOp.getLaunchArg())
382 return parWidthOp.getLaunchArg();
383 return {};
384}
385
386std::optional<uint64_t>
387ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
388 auto knownParWidth = getKnownLaunchArg(parDim);
389 if (knownParWidth.has_value())
390 return getConstantIntValue(knownParWidth.value());
391 return {};
392}
393
394BlockArgument ComputeRegionOp::appendInputArg(Value value) {
395 getInputArgsMutable().append(value);
396 return getBody()->addArgument(value.getType(), getLoc());
397}
398
399bool ComputeRegionOp::isEffectivelySerial() {
400 auto *ctx = getContext();
401
402 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
403 return true;
404
405 auto checkDim = [&](GPUParallelDimAttr dim) -> bool {
406 auto val = getKnownConstantLaunchArg(dim);
407 return val && *val == 1;
408 };
409
410 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
411 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
412 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
413 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
414 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
415 checkDim(GPUParallelDimAttr::blockZDim(ctx));
416}
417
418BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
419 for (auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
420 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
421 assert(parOp);
422 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
423 if (launchArgDim == parDim) {
424 assert(pos < getRegion().front().getNumArguments() &&
425 "launch arg position out of range");
426 return getRegion().front().getArgument(pos);
427 }
428 }
429 llvm_unreachable("attempting to get unspecified parDim");
430}
431
432SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
434 for (auto launchArg : getLaunchArgs()) {
435 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
436 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
437 int64_t dimInt = launchArgDim.getValue().getInt();
438 parDims.push_back(intToParDim(getContext(), dimInt));
439 }
440 return parDims;
441}
442
443Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
444 unsigned argNumber = blockArg.getArgNumber();
445 unsigned numLaunchArgs = getLaunchArgs().size();
446 assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
447 "invalid block argument");
448 if (argNumber < numLaunchArgs)
449 return getLaunchArgs()[argNumber];
450 return getInputArgs()[argNumber - numLaunchArgs];
451}
452
453BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
454 return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
455}
456
457LogicalResult ComputeRegionOp::verify() {
458 for (auto op : getLaunchArgs())
459 if (!op.getDefiningOp<acc::ParWidthOp>())
460 return emitOpError(
461 "launch arguments must be results of acc.par_width operations");
462
463 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
464 unsigned actualBlockArgs = getRegion().front().getNumArguments();
465 if (expectedBlockArgs != actualBlockArgs)
466 return emitOpError("expected ")
467 << expectedBlockArgs << " block arguments (launch + input), got "
468 << actualBlockArgs;
469
470 return success();
471}
472
473void ComputeRegionOp::print(OpAsmPrinter &p) {
474 ValueRange regionArgs = getBody()->getArguments();
475 ValueRange launchArgs = getLaunchArgs();
476 ValueRange inputArgs = getInputArgs();
477
478 assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
479 "region args mismatch");
480
481 if (getStream())
482 p << " stream(" << getStream() << " : " << getStream().getType() << ")";
483
484 size_t i = 0;
485 if (!launchArgs.empty()) {
486 p << " launch(";
487 for (size_t j = 0; j < launchArgs.size(); ++j, ++i) {
488 p << regionArgs[i] << " = " << launchArgs[j];
489 if (j < launchArgs.size() - 1)
490 p << ", ";
491 }
492 p << ")";
493 }
494 if (!inputArgs.empty()) {
495 p << " ins(";
496 for (size_t j = 0; j < inputArgs.size(); ++j, ++i) {
497 p << regionArgs[i] << " = " << inputArgs[j];
498 if (j < inputArgs.size() - 1)
499 p << ", ";
500 }
501 p << ") : (";
502 for (size_t j = 0; j < inputArgs.size(); ++j) {
503 p << inputArgs[j].getType();
504 if (j < inputArgs.size() - 1)
505 p << ", ";
506 }
507 p << ")";
508 }
509 p.printOptionalArrowTypeList(getResultTypes());
510 p << " ";
511 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
512 p.printOptionalAttrDict((*this)->getAttrs(),
513 /*elidedAttrs=*/getOperandSegmentSizeAttr());
514}
515
516ParseResult ComputeRegionOp::parse(OpAsmParser &parser,
518 auto &builder = parser.getBuilder();
519
521 OpAsmParser::UnresolvedOperand streamOperand;
522 Type streamType;
525 SmallVector<Type> types;
526
527 bool hasStream = false;
528 if (succeeded(parser.parseOptionalKeyword("stream"))) {
529 hasStream = true;
530 if (parser.parseLParen() || parser.parseOperand(streamOperand) ||
531 parser.parseColon() || parser.parseType(streamType) ||
532 parser.parseRParen())
533 return failure();
534 }
535
536 if (succeeded(parser.parseOptionalKeyword("launch"))) {
537 if (parser.parseAssignmentList(regionArgs, launchOperands))
538 return failure();
539 Type indexType = builder.getIndexType();
540 for (size_t i = 0; i < regionArgs.size(); ++i)
541 types.push_back(indexType);
542 }
543
544 if (succeeded(parser.parseOptionalKeyword("ins"))) {
545 if (parser.parseAssignmentList(regionArgs, inputOperands) ||
546 parser.parseColon() || parser.parseLParen() ||
547 parser.parseTypeList(types) || parser.parseRParen())
548 return failure();
549 }
550
551 if (parser.parseOptionalArrowTypeList(result.types))
552 return failure();
553
554 for (auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
555 iterArg.type = type;
556
557 Region *body = result.addRegion();
558 if (parser.parseRegion(*body, regionArgs))
559 return failure();
560
561 const size_t numLaunchOperands = launchOperands.size();
562 const size_t numInputOperands = inputOperands.size();
563 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
564 "compute region args mismatch");
565
566 result.addAttribute(
567 ComputeRegionOp::getOperandSegmentSizeAttr(),
568 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunchOperands),
569 static_cast<int32_t>(numInputOperands),
570 hasStream ? 1 : 0}));
571
572 for (size_t i = 0; i < numLaunchOperands; ++i) {
573 if (parser.resolveOperand(launchOperands[i], types[i], result.operands))
574 return failure();
575 }
576
577 for (size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
578 if (parser.resolveOperand(inputOperands[i - numLaunchOperands], types[i],
579 result.operands))
580 return failure();
581 }
582
583 if (hasStream) {
584 if (parser.resolveOperand(streamOperand, streamType, result.operands))
585 return failure();
586 }
587
588 if (parser.parseOptionalAttrDict(result.attributes))
589 return failure();
590
591 return success();
592}
593
594//===----------------------------------------------------------------------===//
595// GPUParallelDimAttr
596//===----------------------------------------------------------------------===//
597
/// Builds the attribute encoding an arbitrary gpu::Processor.
GPUParallelDimAttr GPUParallelDimAttr::get(MLIRContext *context,
                                           gpu::Processor proc) {
  return processorParDim(context, proc);
}

/// Factory for the sequential (non-parallel) dimension.
GPUParallelDimAttr GPUParallelDimAttr::seqDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::Sequential);
}

/// Factory for the thread-x dimension.
GPUParallelDimAttr GPUParallelDimAttr::threadXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadX);
}

/// Factory for the thread-y dimension.
GPUParallelDimAttr GPUParallelDimAttr::threadYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadY);
}

/// Factory for the thread-z dimension.
GPUParallelDimAttr GPUParallelDimAttr::threadZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadZ);
}

/// Factory for the block-x dimension.
GPUParallelDimAttr GPUParallelDimAttr::blockXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockX);
}

/// Factory for the block-y dimension.
GPUParallelDimAttr GPUParallelDimAttr::blockYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockY);
}

/// Factory for the block-z dimension.
GPUParallelDimAttr GPUParallelDimAttr::blockZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockZ);
}
630
631Attribute GPUParallelDimAttr::parse(AsmParser &parser, Type type) {
632 GPUParallelDimAttr dim;
633 if (parser.parseLess() || parseProcessorValue(parser, dim) ||
634 parser.parseGreater()) {
635 parser.emitError(parser.getCurrentLocation(),
636 "expected format `<` processor_name `>`");
637 return {};
638 }
639 return dim;
640}
641
/// Prints the attribute as `<processor_name>`; inverse of parse above.
void GPUParallelDimAttr::print(AsmPrinter &printer) const {
  printer << "<";
  printProcessorValue(printer, *this);
  printer << ">";
}
647
648GPUParallelDimAttr GPUParallelDimAttr::threadDim(MLIRContext *context,
649 unsigned index) {
650 assert(index <= 2 && "thread dimension index must be 0, 1, or 2");
651 switch (index) {
652 case 0:
653 return threadXDim(context);
654 case 1:
655 return threadYDim(context);
656 case 2:
657 return threadZDim(context);
658 }
659 llvm_unreachable("validated thread dimension index");
660}
661
662GPUParallelDimAttr GPUParallelDimAttr::blockDim(MLIRContext *context,
663 unsigned index) {
664 assert(index <= 2 && "block dimension index must be 0, 1, or 2");
665 switch (index) {
666 case 0:
667 return blockXDim(context);
668 case 1:
669 return blockYDim(context);
670 case 2:
671 return blockZDim(context);
672 }
673 llvm_unreachable("validated block dimension index");
674}
675
/// Decodes the stored integer back into the gpu::Processor it encodes.
gpu::Processor GPUParallelDimAttr::getProcessor() const {
  return indexToGpuProcessor(getValue().getInt());
}

/// Total order of dimensions: Sequential (0) up through BlockZ (6).
int GPUParallelDimAttr::getOrder() const {
  return gpuProcessorIndex(getProcessor());
}
683
684GPUParallelDimAttr GPUParallelDimAttr::getOneHigher() const {
685 int order = getOrder();
686 if (order >= 6) // BlockZ is the highest
687 return *this;
688 return get(getContext(), indexToGpuProcessor(order + 1));
689}
690
691GPUParallelDimAttr GPUParallelDimAttr::getOneLower() const {
692 int order = getOrder();
693 if (order <= 0) // Sequential is the lowest
694 return *this;
695 return get(getContext(), indexToGpuProcessor(order - 1));
696}
697
// Predicate helpers: one per processor kind, plus any-thread/any-block
// aggregates used by GPUParallelDimsAttr below.
bool GPUParallelDimAttr::isSeq() const {
  return getProcessor() == gpu::Processor::Sequential;
}
bool GPUParallelDimAttr::isThreadX() const {
  return getProcessor() == gpu::Processor::ThreadX;
}
bool GPUParallelDimAttr::isThreadY() const {
  return getProcessor() == gpu::Processor::ThreadY;
}
bool GPUParallelDimAttr::isThreadZ() const {
  return getProcessor() == gpu::Processor::ThreadZ;
}
bool GPUParallelDimAttr::isBlockX() const {
  return getProcessor() == gpu::Processor::BlockX;
}
bool GPUParallelDimAttr::isBlockY() const {
  return getProcessor() == gpu::Processor::BlockY;
}
bool GPUParallelDimAttr::isBlockZ() const {
  return getProcessor() == gpu::Processor::BlockZ;
}
bool GPUParallelDimAttr::isAnyThread() const {
  return isThreadX() || isThreadY() || isThreadZ();
}
bool GPUParallelDimAttr::isAnyBlock() const {
  return isBlockX() || isBlockY() || isBlockZ();
}
725
726//===----------------------------------------------------------------------===//
727// GPUParallelDimsAttr
728//===----------------------------------------------------------------------===//
729
/// Builds the canonical single-element sequential dimension list.
GPUParallelDimsAttr GPUParallelDimsAttr::seq(MLIRContext *ctx) {
  return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
}
733
734bool GPUParallelDimsAttr::isSeq() const {
735 assert(!getArray().empty() && "no par_dims found");
736 if (getArray().size() == 1) {
737 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
738 assert(parDim && "expected GPUParallelDimAttr");
739 return parDim.isSeq();
740 }
741 return false;
742}
743
/// Parallel is defined as the negation of sequential.
bool GPUParallelDimsAttr::isParallel() const { return !isSeq(); }

/// True when the list names more than one launch dimension.
bool GPUParallelDimsAttr::isMultiDim() const { return getArray().size() > 1; }
747
748bool GPUParallelDimsAttr::hasAnyBlockLevel() const {
749 return llvm::any_of(
750 getArray(), [](const GPUParallelDimAttr &p) { return p.isAnyBlock(); });
751}
752
753bool GPUParallelDimsAttr::hasOnlyBlockLevel() const {
754 return !getArray().empty() &&
755 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
756 return p.isAnyBlock();
757 });
758}
759
760bool GPUParallelDimsAttr::hasOnlyThreadYLevel() const {
761 return !getArray().empty() &&
762 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
763 return p.isThreadY();
764 });
765}
766
767bool GPUParallelDimsAttr::hasOnlyThreadXLevel() const {
768 return !getArray().empty() &&
769 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
770 return p.isThreadX();
771 });
772}
773
774Attribute GPUParallelDimsAttr::parse(AsmParser &parser, Type type) {
775 auto delimiter = AsmParser::Delimiter::Square;
777 auto parseParDim = [&]() -> ParseResult {
778 GPUParallelDimAttr dim;
779 if (parseProcessorValue(parser, dim))
780 return failure();
781 parDims.push_back(dim);
782 return success();
783 };
784 if (parser.parseCommaSeparatedList(delimiter, parseParDim,
785 "list of OpenACC GPU parallel dimensions"))
786 return {};
787 return GPUParallelDimsAttr::get(parser.getContext(), parDims);
788}
789
790void GPUParallelDimsAttr::print(AsmPrinter &printer) const {
791 printer << "[";
792 llvm::interleaveComma(getArray(), printer,
793 [&printer](const GPUParallelDimAttr &p) {
794 printProcessorValue(printer, p);
795 });
796 printer << "]";
797}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1224
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1234
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
b getContext())
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:441
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5098
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5067
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5171
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:5075
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.