MLIR 23.0.0git
GPUDialect.cpp
Go to the documentation of this file.
1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
30#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
43#include <cassert>
44#include <numeric>
45
46using namespace mlir;
47using namespace mlir::gpu;
48
49#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// GPU Device Mapping Attributes
53//===----------------------------------------------------------------------===//
54
55int64_t GPUBlockMappingAttr::getMappingId() const {
56 return static_cast<int64_t>(getBlock());
57}
58
59bool GPUBlockMappingAttr::isLinearMapping() const {
60 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
61}
62
63int64_t GPUBlockMappingAttr::getRelativeIndex() const {
64 return isLinearMapping()
65 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
66 : getMappingId();
67}
68
69int64_t GPUWarpgroupMappingAttr::getMappingId() const {
70 return static_cast<int64_t>(getWarpgroup());
71}
72
73bool GPUWarpgroupMappingAttr::isLinearMapping() const {
74 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
75}
76
77int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
78 return isLinearMapping()
79 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
80 : getMappingId();
81}
82
83int64_t GPUWarpMappingAttr::getMappingId() const {
84 return static_cast<int64_t>(getWarp());
85}
86
87bool GPUWarpMappingAttr::isLinearMapping() const {
88 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
89}
90
91int64_t GPUWarpMappingAttr::getRelativeIndex() const {
92 return isLinearMapping()
93 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
94 : getMappingId();
95}
96
97int64_t GPUThreadMappingAttr::getMappingId() const {
98 return static_cast<int64_t>(getThread());
99}
100
101bool GPUThreadMappingAttr::isLinearMapping() const {
102 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
103}
104
105int64_t GPUThreadMappingAttr::getRelativeIndex() const {
106 return isLinearMapping()
107 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
108 : getMappingId();
109}
110
111int64_t GPULaneMappingAttr::getMappingId() const {
112 return static_cast<int64_t>(getLane());
113}
114
115bool GPULaneMappingAttr::isLinearMapping() const {
116 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
117}
118
119int64_t GPULaneMappingAttr::getRelativeIndex() const {
120 return isLinearMapping()
121 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
122 : getMappingId();
123}
124
125int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
126
127/// 8 4 0
128/// Example mask : 0 0 0 1 1 0 1 0 0
129///
130/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
131/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
132///
133/// Example mask : 0 0 0 1 1 0 1 0 0
134/// Example filter: 0 0 0 0 1 1 1 1 1
135/// Intersection : 0 0 0 0 1 0 1 0 0
136/// PopCnt : 2
137Value GPUMappingMaskAttr::createLogicalLinearMappingId(
138 OpBuilder &b, Value physicalLinearMappingId) const {
139 Location loc = physicalLinearMappingId.getLoc();
140 Value mask =
141 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
142 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
143 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
144 filter = arith::SubIOp::create(b, loc, filter, one);
145 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
146 return math::CtPopOp::create(b, loc, filteredId);
147}
148
149/// 8 4 0
150/// Example mask : 0 0 0 1 1 0 1 0 0
151///
152/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
153/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
154///
155/// Example mask : 0 0 0 1 1 0 1 0 0
156/// Example filter: 0 0 0 1 0 0 0 0 0
157/// Intersection : 0 0 0 1 0 0 0 0 0
158/// Cmp : 1
159Value GPUMappingMaskAttr::createIsActiveIdPredicate(
160 OpBuilder &b, Value physicalLinearMappingId) const {
161 Location loc = physicalLinearMappingId.getLoc();
162 Value mask =
163 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
164 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
165 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
166 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
167 Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
168 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
169 zero);
170}
171
172int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
173 return static_cast<int64_t>(getAddressSpace());
174}
175
/// Memory-space mappings have no linear form; calling this is a programming
/// error.
bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}
179
/// Memory-space mappings have no relative index; calling this is a programming
/// error.
int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
183
184//===----------------------------------------------------------------------===//
185// MMAMatrixType
186//===----------------------------------------------------------------------===//
187
189 StringRef operand) {
190 return Base::get(elementType.getContext(), shape, elementType, operand);
191}
192
195 ArrayRef<int64_t> shape, Type elementType,
196 StringRef operand) {
197 return Base::getChecked(emitError, elementType.getContext(), shape,
198 elementType, operand);
199}
200
201unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
202
204 return getImpl()->getShape();
205}
206
207Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
208
209StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
210
212 return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
213 elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
214 elementType.isInteger(32);
215}
216
217LogicalResult
219 ArrayRef<int64_t> shape, Type elementType,
220 StringRef operand) {
221 if (operand != "AOp" && operand != "BOp" && operand != "COp")
222 return emitError() << "operand expected to be one of AOp, BOp or COp";
223
224 if (shape.size() != 2)
225 return emitError() << "MMAMatrixType must have exactly two dimensions";
226
227 if (!MMAMatrixType::isValidElementType(elementType))
228 return emitError()
229 << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
230
231 return success();
232}
233
234//===----------------------------------------------------------------------===//
235// GPUDialect
236//===----------------------------------------------------------------------===//
237
238bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
239 if (!memorySpace)
240 return false;
241 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
242 return gpuAttr.getValue() == getWorkgroupAddressSpace();
243 return false;
244}
245
246bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
247 Attribute memorySpace = type.getMemorySpace();
248 return isWorkgroupMemoryAddressSpace(memorySpace);
249}
250
251bool GPUDialect::isKernel(Operation *op) {
252 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
253 return static_cast<bool>(isKernelAttr);
254}
255
namespace {
/// This class defines the interface for handling inlining with gpu
/// operations: it unconditionally permits inlining of all gpu dialect ops.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined into any region.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
268
void GPUDialect::initialize() {
  // Register the dialect's custom types.
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  // Register ops and attributes generated from the TableGen definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  // Declare interfaces whose implementations are registered elsewhere
  // (promised interfaces), avoiding a hard dependency here.
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
                            ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
                            BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
                            LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
                            SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
}
292
293static std::string getSparseHandleKeyword(SparseHandleKind kind) {
294 switch (kind) {
296 return "sparse.dntensor_handle";
298 return "sparse.spmat_handle";
300 return "sparse.spgemmop_handle";
301 }
302 llvm_unreachable("unknown sparse handle kind");
303 return "";
304}
305
306Type GPUDialect::parseType(DialectAsmParser &parser) const {
307 // Parse the main keyword for the type.
308 StringRef keyword;
309 if (parser.parseKeyword(&keyword))
310 return Type();
311 MLIRContext *context = getContext();
312
313 // Handle 'async token' types.
314 if (keyword == "async.token")
315 return AsyncTokenType::get(context);
316
317 if (keyword == "mma_matrix") {
318 SMLoc beginLoc = parser.getNameLoc();
319
320 // Parse '<'.
321 if (parser.parseLess())
322 return nullptr;
323
324 // Parse the size and elementType.
325 SmallVector<int64_t> shape;
326 Type elementType;
327 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
328 parser.parseType(elementType))
329 return nullptr;
330
331 // Parse ','
332 if (parser.parseComma())
333 return nullptr;
334
335 // Parse operand.
336 std::string operand;
337 if (failed(parser.parseOptionalString(&operand)))
338 return nullptr;
339
340 // Parse '>'.
341 if (parser.parseGreater())
342 return nullptr;
343
345 parser.getEncodedSourceLoc(beginLoc)),
346 shape, elementType, operand);
347 }
348
350 return SparseDnTensorHandleType::get(context);
352 return SparseSpMatHandleType::get(context);
354 return SparseSpGEMMOpHandleType::get(context);
355
356 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
357 return Type();
358}
359// TODO: print refined type here. Notice that should be corresponding to the
360// parser
361void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
362 TypeSwitch<Type>(type)
363 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
364 .Case<SparseDnTensorHandleType>([&](Type) {
366 })
367 .Case<SparseSpMatHandleType>(
369 .Case<SparseSpGEMMOpHandleType>([&](Type) {
371 })
372 .Case([&](MMAMatrixType fragTy) {
373 os << "mma_matrix<";
374 auto shape = fragTy.getShape();
375 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
376 os << *dim << 'x';
377 os << shape.back() << 'x' << fragTy.getElementType();
378 os << ", \"" << fragTy.getOperand() << "\"" << '>';
379 })
380 .DefaultUnreachable("unexpected 'gpu' type kind");
381}
382
383static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
384 NamedAttribute attr) {
385 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
386 if (!array)
387 return op->emitOpError(Twine(attr.getName()) +
388 " must be a dense i32 array");
389 if (array.size() != 3)
390 return op->emitOpError(Twine(attr.getName()) +
391 " must contain exactly 3 elements");
392 return success();
393}
394
395LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
396 NamedAttribute attr) {
397 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
398 return verifyKnownLaunchSizeAttr(op, attr);
399 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
400 return verifyKnownLaunchSizeAttr(op, attr);
401 if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
402 return verifyKnownLaunchSizeAttr(op, attr);
403 if (!llvm::isa<UnitAttr>(attr.getValue()) ||
404 attr.getName() != getContainerModuleAttrName())
405 return success();
406
407 auto module = dyn_cast<ModuleOp>(op);
408 if (!module)
409 return op->emitError("expected '")
410 << getContainerModuleAttrName() << "' attribute to be attached to '"
411 << ModuleOp::getOperationName() << '\'';
412 return success();
413}
414
415/// Parses an optional list of async operands with an optional leading keyword.
416/// (`async`)? (`[` ssa-id-list `]`)?
417///
418/// This method is used by the tablegen assembly format for async ops as well.
419static ParseResult parseAsyncDependencies(
420 OpAsmParser &parser, Type &asyncTokenType,
422 auto loc = parser.getCurrentLocation();
423 if (succeeded(parser.parseOptionalKeyword("async"))) {
424 if (parser.getNumResults() == 0)
425 return parser.emitError(loc, "needs to be named when marked 'async'");
426 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
427 }
428 return parser.parseOperandList(asyncDependencies,
430}
431
432/// Prints optional async dependencies with its leading keyword.
433/// (`async`)? (`[` ssa-id-list `]`)?
434// Used by the tablegen assembly format for several async ops.
436 Type asyncTokenType,
437 OperandRange asyncDependencies) {
438 if (asyncTokenType)
439 printer << "async";
440 if (asyncDependencies.empty())
441 return;
442 if (asyncTokenType)
443 printer << ' ';
444 printer << llvm::interleaved_array(asyncDependencies);
445}
446
447// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
448/// Parses a GPU function memory attribution.
449///
450/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
451/// (`private` `(` ssa-id-and-type-list `)`)?
452///
453/// Note that this function parses only one of the two similar parts, with the
454/// keyword provided as argument.
455static ParseResult
456parseAttributions(OpAsmParser &parser, StringRef keyword,
458 // If we could not parse the keyword, just assume empty list and succeed.
459 if (failed(parser.parseOptionalKeyword(keyword)))
460 return success();
461
463 /*allowType=*/true);
464}
465
466static void printAttributions(OpAsmPrinter &p, StringRef keyword,
468 ArrayAttr attributes = {}) {
469 if (values.empty())
470 return;
471
472 p << ' ' << keyword << '(';
473 llvm::interleaveComma(
474 llvm::enumerate(values), p, [&p, attributes](auto pair) {
475 BlockArgument v = pair.value();
476 p << v << " : " << v.getType();
477
478 size_t attributionIndex = pair.index();
479 DictionaryAttr attrs;
480 if (attributes && attributionIndex < attributes.size())
481 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
482 if (attrs)
483 p.printOptionalAttrDict(attrs.getValue());
484 });
485 p << ')';
486}
487
488/// Verifies a GPU function memory attribution.
489static LogicalResult verifyAttributions(Operation *op,
490 ArrayRef<BlockArgument> attributions,
491 gpu::AddressSpace memorySpace) {
492 for (Value v : attributions) {
493 auto type = llvm::dyn_cast<MemRefType>(v.getType());
494 if (!type)
495 return op->emitOpError() << "expected memref type in attribution";
496
497 // We can only verify the address space if it hasn't already been lowered
498 // from the AddressSpaceAttr to a target-specific numeric value.
499 auto addressSpace =
500 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
501 if (!addressSpace)
502 continue;
503 if (addressSpace.getValue() != memorySpace)
504 return op->emitOpError()
505 << "expected memory space " << stringifyAddressSpace(memorySpace)
506 << " in attribution";
507 }
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// AllReduceOp
513//===----------------------------------------------------------------------===//
514
515static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
516 Type resType) {
517 using Kind = gpu::AllReduceOperation;
518 if (llvm::is_contained(
519 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
520 opName)) {
521 if (!isa<FloatType>(resType))
522 return failure();
523 }
524
525 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
526 Kind::AND, Kind::OR, Kind::XOR},
527 opName)) {
528 if (!isa<IntegerType>(resType))
529 return failure();
530 }
531
532 return success();
533}
534
/// Verifies that the reduction is specified in exactly one way — either via
/// the `op` attribute or via a non-empty reduction region — and that the
/// region, when present, takes two operands of the result type and yields a
/// single value of the result type through gpu.yield.
LogicalResult gpu::AllReduceOp::verifyRegions() {
  // Exactly one of {op attribute, non-empty body} must be provided.
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    // The region combines two partial results of the reduction.
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    // Every gpu.yield must return one value of the result type, and at least
    // one such yield must be present.
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    // Attribute form: the reduction kind must be compatible with the type.
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
568
570 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
571 if (!launchOp)
572 return false;
573
574 Region &body = launchOp.getBody();
575 assert(!body.empty() && "Invalid region");
576
577 // Only convert ops in gpu::launch entry block for now.
578 return op->getBlock() == &body.front();
579}
580
581OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
582 if (!getUniform() && canMakeGroupOpUniform(*this)) {
583 setUniform(true);
584 return getResult();
585 }
586
587 return nullptr;
588}
589
590// TODO: Support optional custom attributes (without dialect prefix).
591static ParseResult parseAllReduceOperation(AsmParser &parser,
592 AllReduceOperationAttr &attr) {
593 StringRef enumStr;
594 if (!parser.parseOptionalKeyword(&enumStr)) {
595 std::optional<AllReduceOperation> op =
596 gpu::symbolizeAllReduceOperation(enumStr);
597 if (!op)
598 return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
599 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
600 }
601 return success();
602}
603
605 AllReduceOperationAttr attr) {
606 if (attr)
607 attr.print(printer);
608}
609
610//===----------------------------------------------------------------------===//
611// SubgroupReduceOp
612//===----------------------------------------------------------------------===//
613
/// Verifies the element type against the reduction kind, and that the
/// optional cluster size/stride are powers of two (stride requires a size).
LogicalResult gpu::SubgroupReduceOp::verify() {
  // For vector results, the reduction applies elementwise; reject scalable
  // vectors and check the element type instead.
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  // Cluster size, when given, must be a power of two.
  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  // A non-default stride only makes sense together with a cluster size, and
  // must itself be a power of two (the default stride 1 trivially is).
  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
651
652OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
653 if (getClusterSize() == 1)
654 return getValue();
655
656 if (!getUniform() && canMakeGroupOpUniform(*this)) {
657 setUniform(true);
658 return getResult();
659 }
660
661 return nullptr;
662}
663
664//===----------------------------------------------------------------------===//
665// AsyncOpInterface
666//===----------------------------------------------------------------------===//
667
669 op->insertOperands(0, {token});
670 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
671 return;
672 auto attrName =
674 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
675
676 // Async dependencies is the only variadic operand.
677 if (!sizeAttr)
678 return;
679
680 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
681 ++sizes.front();
682 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
683}
684
685//===----------------------------------------------------------------------===//
686// LaunchOp
687//===----------------------------------------------------------------------===//
688
/// Builds a gpu.launch: records the attribution count, appends operands in
/// segment order (async deps, grid, block, optional cluster, optional dynamic
/// shared memory size), optional module/function attributes, and creates the
/// body region with its index-typed config arguments plus attributions.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  // Cluster sizes and the dynamic shared memory size are optional operands;
  // append only those that were provided.
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute. Segment layout: [asyncDependencies,
  // gridX, gridY, gridZ, blockX, blockY, blockZ, clusterX, clusterY,
  // clusterZ, dynamicSharedMemorySize].
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
751
752KernelDim3 LaunchOp::getBlockIds() {
753 assert(!getBody().empty() && "LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
755 return KernelDim3{args[0], args[1], args[2]};
756}
757
758KernelDim3 LaunchOp::getThreadIds() {
759 assert(!getBody().empty() && "LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
761 return KernelDim3{args[3], args[4], args[5]};
762}
763
764KernelDim3 LaunchOp::getGridSize() {
765 assert(!getBody().empty() && "LaunchOp body must not be empty.");
766 auto args = getBody().getArguments();
767 return KernelDim3{args[6], args[7], args[8]};
768}
769
770KernelDim3 LaunchOp::getBlockSize() {
771 assert(!getBody().empty() && "LaunchOp body must not be empty.");
772 auto args = getBody().getArguments();
773 return KernelDim3{args[9], args[10], args[11]};
774}
775
776std::optional<KernelDim3> LaunchOp::getClusterIds() {
777 assert(!getBody().empty() && "LaunchOp body must not be empty.");
778 if (!hasClusterSize())
779 return std::nullopt;
780 auto args = getBody().getArguments();
781 return KernelDim3{args[12], args[13], args[14]};
782}
783
784std::optional<KernelDim3> LaunchOp::getClusterSize() {
785 assert(!getBody().empty() && "LaunchOp body must not be empty.");
786 if (!hasClusterSize())
787 return std::nullopt;
788 auto args = getBody().getArguments();
789 return KernelDim3{args[15], args[16], args[17]};
790}
791
792KernelDim3 LaunchOp::getGridSizeOperandValues() {
793 auto operands = getOperands().drop_front(getAsyncDependencies().size());
794 return KernelDim3{operands[0], operands[1], operands[2]};
795}
796
797KernelDim3 LaunchOp::getBlockSizeOperandValues() {
798 auto operands = getOperands().drop_front(getAsyncDependencies().size());
799 return KernelDim3{operands[3], operands[4], operands[5]};
800}
801
802std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
803 auto operands = getOperands().drop_front(getAsyncDependencies().size());
804 if (!hasClusterSize())
805 return std::nullopt;
806 return KernelDim3{operands[6], operands[7], operands[8]};
807}
808
809LogicalResult LaunchOp::verify() {
810 if (!(hasClusterSize()) &&
811 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
812 return emitOpError() << "cluster size must be all present";
813 return success();
814}
815
/// Verifies the body region: argument count, attribution address spaces,
/// exiting terminators, and async result naming.
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (getBody().empty()) {
    return emitOpError("body region is empty");
  }
  if (getBody().getNumArguments() <
      kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
    return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  // An async launch must bind its token result to a name.
  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
857
858// Pretty-print the kernel grid/block size assignment as
859// (%iter-x, %iter-y, %iter-z) in
860// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
861// where %size-* and %iter-* will correspond to the body region arguments.
863 KernelDim3 operands, KernelDim3 ids) {
864 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
865 p << size.x << " = " << operands.x << ", ";
866 p << size.y << " = " << operands.y << ", ";
867 p << size.z << " = " << operands.z << ')';
868}
869
/// Prints the launch op in its custom form: async token/dependencies, the
/// cluster/blocks/threads size assignments, optional dynamic shared memory
/// size, optional module/function symbols, attributions, body and attributes.
void LaunchOp::print(OpAsmPrinter &p) {
  // Print the async token and its dependency list, if any.
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  // Print memory attributions, then the body region without its entry block
  // arguments (they are implied by the size assignments above).
  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  // Elide attributes that the custom syntax above already encodes.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
919
920// Parse the size assignment blocks for blocks and threads. These have the form
921// (%region_arg, %region_arg, %region_arg) in
922// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
923// where %region_arg are percent-identifiers for the region arguments to be
924// introduced further (SSA defs), and %operand are percent-identifiers for the
925// SSA value uses.
926static ParseResult
931 StringRef keyword) {
932 assert(indices.size() == 3 && "space for three indices expected");
935 /*allowResultNumber=*/false) ||
936 parser.parseKeyword("in") || parser.parseLParen())
937 return failure();
938
939 if (args.size() != 3) {
940 return parser.emitError(parser.getNameLoc())
941 << keyword << " expects 3 arguments, but got " << args.size();
942 }
943 std::move(args.begin(), args.end(), indices.begin());
944
945 for (int i = 0; i < 3; ++i) {
946 if (i != 0 && parser.parseComma())
947 return failure();
948 if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
949 parser.parseEqual() || parser.parseOperand(sizes[i]))
950 return failure();
951 }
952
953 return parser.parseRParen();
954}
955
956/// Parses a Launch operation.
957/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
958/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
959/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
960/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
961/// (`dynamic_shared_memory_size` ssa-use)?
962/// (`module(` symbol-ref-id `)`)?
963/// (`function(` symbol-ref-id `)`)?
964/// memory-attribution
965/// region attr-dict?
966/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
967ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
968 // Sizes of the grid and block.
969 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
970 sizes(LaunchOp::kNumConfigOperands);
971
972 // Region arguments to be created.
973 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
974 LaunchOp::kNumConfigRegionAttributes);
975
976 // Parse optional async dependencies.
977 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
978 Type asyncTokenType;
979 if (failed(
980 parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
981 parser.resolveOperands(asyncDependencies, asyncTokenType,
982 result.operands))
983 return failure();
984 if (parser.getNumResults() > 0) {
985 if (!asyncTokenType)
986 return parser.emitError(
987 parser.getNameLoc(),
988 "gpu.launch requires 'async' keyword to return a value");
989 result.types.push_back(asyncTokenType);
990 }
991
992 bool hasCluster = false;
993 if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
994 hasCluster = true;
995 sizes.resize(9);
996 regionArgs.resize(18);
997 }
998 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
999 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1000
1001 // Last three segment assigns the cluster size. In the region argument
1002 // list, this is last 6 arguments.
1003 if (hasCluster) {
1005 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1006 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1007 return failure();
1008 }
1009 // Parse the size assignment segments: the first segment assigns grid sizes
1010 // and defines values for block identifiers; the second segment assigns block
1011 // sizes and defines values for thread identifiers. In the region argument
1012 // list, identifiers precede sizes, and block-related values precede
1013 // thread-related values.
1014 if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
1015 parseSizeAssignment(parser, sizesRef.take_front(3),
1016 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1017 LaunchOp::getBlocksKeyword()) ||
1018 parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
1019 parseSizeAssignment(parser, sizesRef.drop_front(3),
1020 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1021 LaunchOp::getThreadsKeyword()) ||
1022 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
1023 result.operands))
1024 return failure();
1025
1026 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1027 bool hasDynamicSharedMemorySize = false;
1028 if (!parser.parseOptionalKeyword(
1029 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1030 hasDynamicSharedMemorySize = true;
1031 if (parser.parseOperand(dynamicSharedMemorySize) ||
1032 parser.resolveOperand(dynamicSharedMemorySize,
1033 parser.getBuilder().getI32Type(),
1034 result.operands))
1035 return failure();
1036 }
1037
1038 // Parse optional module attribute.
1039 StringRef moduleAttrName = getModuleAttrName(result.name);
1040 if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1041 FlatSymbolRefAttr moduleSymbol;
1042 if (parser.parseLParen() ||
1043 parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1044 result.attributes) ||
1045 parser.parseRParen())
1046 return failure();
1047 }
1048 // Parse optional function attribute.
1049 StringRef functionAttrName = getFunctionAttrName(result.name);
1050 if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1051 FlatSymbolRefAttr funcSymbol;
1052 if (parser.parseLParen() ||
1053 parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1054 result.attributes) ||
1055 parser.parseRParen())
1056 return failure();
1057 }
1058
1059 // Create the region arguments, it has kNumConfigRegionAttributes arguments
1060 // that correspond to block/thread identifiers and grid/block sizes, all
1061 // having `index` type, a variadic number of WorkGroup Attributions and
1062 // a variadic number of Private Attributions. The number of WorkGroup
1063 // Attributions is stored in the attr with name:
1064 // LaunchOp::getNumWorkgroupAttributionsAttrName().
1065 Type index = parser.getBuilder().getIndexType();
1066 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1067 LaunchOp::kNumConfigRegionAttributes + 6, index);
1068
1069 SmallVector<OpAsmParser::Argument> regionArguments;
1070 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1071 OpAsmParser::Argument arg;
1072 arg.ssaName = std::get<0>(ssaValueAndType);
1073 arg.type = std::get<1>(ssaValueAndType);
1074 regionArguments.push_back(arg);
1075 }
1076
1077 Builder &builder = parser.getBuilder();
1078 // Parse workgroup memory attributions.
1079 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1080 regionArguments)))
1081 return failure();
1082
1083 // Store the number of operands we just parsed as the number of workgroup
1084 // memory attributions.
1085 unsigned numWorkgroupAttrs = regionArguments.size() -
1086 LaunchOp::kNumConfigRegionAttributes -
1087 (hasCluster ? 6 : 0);
1088 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1089 builder.getI64IntegerAttr(numWorkgroupAttrs));
1090
1091 // Parse private memory attributions.
1092 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1093 regionArguments)))
1094 return failure();
1095
1096 // Introduce the body region and parse it. The region has
1097 // kNumConfigRegionAttributes arguments that correspond to
1098 // block/thread identifiers and grid/block sizes, all having `index` type.
1099 Region *body = result.addRegion();
1100 if (parser.parseRegion(*body, regionArguments) ||
1101 parser.parseOptionalAttrDict(result.attributes))
1102 return failure();
1103
1104 SmallVector<int32_t, 11> segmentSizes(11, 1);
1105 segmentSizes.front() = asyncDependencies.size();
1106
1107 if (!hasCluster) {
1108 segmentSizes[7] = 0;
1109 segmentSizes[8] = 0;
1110 segmentSizes[9] = 0;
1111 }
1112 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1113 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1114 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1115 return success();
1116}
1117
1118/// Simplify the gpu.launch when the range of a thread or block ID is
1119/// trivially known to be one.
1120struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1121 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1122 LogicalResult matchAndRewrite(LaunchOp op,
1123 PatternRewriter &rewriter) const override {
1124 // If the range implies a single value for `id`, replace `id`'s uses by
1125 // zero.
1126 Value zero;
1127 bool simplified = false;
1128 auto constPropIdUses = [&](Value id, Value size) {
1129 // Check if size is trivially one.
1130 if (!matchPattern(size, m_One()))
1131 return;
1132 if (id.getUses().empty())
1133 return;
1134 if (!simplified) {
1135 // Create a zero value the first time.
1136 OpBuilder::InsertionGuard guard(rewriter);
1137 rewriter.setInsertionPointToStart(&op.getBody().front());
1138 zero =
1139 arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
1140 }
1141 rewriter.replaceAllUsesWith(id, zero);
1142 simplified = true;
1143 };
1144 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1145 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1146 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1147 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1148 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1149 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1150
1151 return success(simplified);
1152 }
1153};
1154
/// Registers LaunchOp canonicalization patterns: currently only
/// FoldLaunchArguments (defined above).
1155 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1156 MLIRContext *context) {
1157 rewrites.add<FoldLaunchArguments>(context);
1158 }
1159
1160/// Adds a new block argument that corresponds to buffers located in
1161/// workgroup memory.
1162BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1163 auto attrName = getNumWorkgroupAttributionsAttrName();
1164 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1165 (*this)->setAttr(attrName,
1166 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1167 return getBody().insertArgument(
1168 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1169}
1170
1171 /// Adds a new block argument that corresponds to buffers located in
1172 /// private memory.
/// The argument is appended at the very end of the region argument list.
1173 BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1174 // Buffers on the private memory always come after buffers on the workgroup
1175 // memory.
1176 return getBody().addArgument(type, loc);
1177 }
1178
1179//===----------------------------------------------------------------------===//
1180// LaunchFuncOp
1181//===----------------------------------------------------------------------===//
1182
1183void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1184 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1185 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1186 ValueRange kernelOperands, Type asyncTokenType,
1187 ValueRange asyncDependencies,
1188 std::optional<KernelDim3> clusterSize) {
1189 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1190 "expected a symbol reference with a single nested reference");
1191 result.addOperands(asyncDependencies);
1192 if (asyncTokenType)
1193 result.types.push_back(builder.getType<AsyncTokenType>());
1194
1195 // Add grid and block sizes as op operands, followed by the data operands.
1196 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1198 if (clusterSize.has_value())
1199 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1200 if (dynamicSharedMemorySize)
1201 result.addOperands(dynamicSharedMemorySize);
1202 result.addOperands(kernelOperands);
1203
1204 Properties &prop = result.getOrAddProperties<Properties>();
1205 prop.kernel = kernelSymbol;
1206 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1207 // Initialize the segment sizes to 1.
1208 llvm::fill(prop.operandSegmentSizes, 1);
1209 prop.operandSegmentSizes[0] = asyncDependencies.size();
1210 if (!clusterSize.has_value()) {
1211 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1212 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1213 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1214 }
1215 prop.operandSegmentSizes[segmentSizesLen - 3] =
1216 dynamicSharedMemorySize ? 1 : 0;
1217 prop.operandSegmentSizes[segmentSizesLen - 2] =
1218 static_cast<int32_t>(kernelOperands.size());
1219 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1220}
1221
1222void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1223 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1224 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1225 ValueRange kernelOperands, Type asyncTokenType,
1226 ValueRange asyncDependencies,
1227 std::optional<KernelDim3> clusterSize) {
1228 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1229 auto kernelSymbol =
1230 SymbolRefAttr::get(kernelModule.getNameAttr(),
1231 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1232 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1233 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1234 asyncDependencies, clusterSize);
1235}
1236
1237void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1238 SymbolRefAttr kernel, KernelDim3 gridSize,
1239 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1240 ValueRange kernelOperands, Value asyncObject,
1241 std::optional<KernelDim3> clusterSize) {
1242 // Add grid and block sizes as op operands, followed by the data operands.
1243 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1245 if (clusterSize.has_value())
1246 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1247 if (dynamicSharedMemorySize)
1248 result.addOperands(dynamicSharedMemorySize);
1249 result.addOperands(kernelOperands);
1250 if (asyncObject)
1251 result.addOperands(asyncObject);
1252 Properties &prop = result.getOrAddProperties<Properties>();
1253 prop.kernel = kernel;
1254 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1255 // Initialize the segment sizes to 1.
1256 llvm::fill(prop.operandSegmentSizes, 1);
1257 prop.operandSegmentSizes[0] = 0;
1258 if (!clusterSize.has_value()) {
1259 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1260 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1261 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1262 }
1263 prop.operandSegmentSizes[segmentSizesLen - 3] =
1264 dynamicSharedMemorySize ? 1 : 0;
1265 prop.operandSegmentSizes[segmentSizesLen - 2] =
1266 static_cast<int32_t>(kernelOperands.size());
1267 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1268}
1269
/// Returns the root of the nested kernel symbol, i.e. the name of the
/// module that contains the kernel function.
1270 StringAttr LaunchFuncOp::getKernelModuleName() {
1271 return getKernel().getRootReference();
1272 }
1273
/// Returns the leaf of the nested kernel symbol, i.e. the kernel function
/// name itself.
1274 StringAttr LaunchFuncOp::getKernelName() {
1275 return getKernel().getLeafReference();
1276 }
1277
/// Returns the number of operands passed to the kernel (excludes launch
/// configuration operands).
1278 unsigned LaunchFuncOp::getNumKernelOperands() {
1279 return getKernelOperands().size();
1280 }
1281
/// Returns the i-th kernel operand; `i` must be below
/// getNumKernelOperands().
1282 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1283 return getKernelOperands()[i];
1284 }
1285
/// Returns the grid size operands (x, y, z). After the async dependencies,
/// operands 0-2 hold the grid size.
1286 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1287 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1288 return KernelDim3{operands[0], operands[1], operands[2]};
1289 }
1290
/// Returns the block size operands (x, y, z); operands 3-5 after the async
/// dependencies.
1291 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1292 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1293 return KernelDim3{operands[3], operands[4], operands[5]};
1294 }
1295
/// Returns the cluster size operands (x, y, z); operands 6-8 after the async
/// dependencies. Only valid when hasClusterSize() holds.
1296 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1297 assert(hasClusterSize() &&
1298 "cluster size is not set, check hasClusterSize() first");
1299 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1300 return KernelDim3{operands[6], operands[7], operands[8]};
1301 }
1302
1303LogicalResult LaunchFuncOp::verify() {
1304 auto module = (*this)->getParentOfType<ModuleOp>();
1305 if (!module)
1306 return emitOpError("expected to belong to a module");
1307
1308 if (!module->getAttrOfType<UnitAttr>(
1309 GPUDialect::getContainerModuleAttrName()))
1310 return emitOpError("expected the closest surrounding module to have the '" +
1311 GPUDialect::getContainerModuleAttrName() +
1312 "' attribute");
1313
1314 if (hasClusterSize()) {
1315 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1316 getClusterSizeZ().getType() != getClusterSizeX().getType())
1317 return emitOpError()
1318 << "expects types of the cluster dimensions must be the same";
1319 }
1320
1321 return success();
1322}
1323
1324LogicalResult
1325LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1326 LaunchFuncOp launchOp = *this;
1327 Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
1328 // GPU modules cannot be nested within each other, escape to resolve the name.
1329 if (isa<GPUModuleOp>(table))
1331
1332 // Ignore launches that are nested more or less deep than functions in the
1333 // module we are currently checking.
1334 if (!launchOp->getParentOp() ||
1335 launchOp->getParentOp()->getParentOp() != table)
1336 return success();
1337
1338 // Ignore launch ops with missing attributes here. The errors will be
1339 // reported by the verifiers of those ops.
1340 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1341 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1342 return success();
1343
1344 // Check that `launch_func` refers to a well-formed GPU kernel container.
1345 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1346 Operation *kernelContainer =
1347 symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
1348 if (!kernelContainer)
1349 return launchOp.emitOpError()
1350 << "kernel container '" << kernelContainerName.getValue()
1351 << "' is undefined";
1352
1353 // If the container is a GPU binary op return success.
1354 if (isa<BinaryOp>(kernelContainer))
1355 return success();
1356
1357 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1358 if (!kernelModule)
1359 return launchOp.emitOpError()
1360 << "kernel module '" << kernelContainerName.getValue()
1361 << "' is undefined";
1362
1363 // Check that `launch_func` refers to a well-formed kernel function.
1364 Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
1365 kernelModule, launchOp.getKernelName());
1366 if (!kernelFunc)
1367 return launchOp.emitOpError("kernel function '")
1368 << launchOp.getKernel() << "' is undefined";
1369 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1370 if (!kernelConvertedFunction) {
1371 InFlightDiagnostic diag = launchOp.emitOpError()
1372 << "referenced kernel '" << launchOp.getKernel()
1373 << "' is not a function";
1374 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
1375 return diag;
1376 }
1377
1378 if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
1379 GPUDialect::getKernelFuncAttrName()))
1380 return launchOp.emitOpError("kernel function is missing the '")
1381 << GPUDialect::getKernelFuncAttrName() << "' attribute";
1382
1383 // TODO: If the kernel isn't a GPU function (which happens during separate
1384 // compilation), do not check type correspondence as it would require the
1385 // verifier to be aware of the type conversion.
1386 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1387 if (!kernelGPUFunction)
1388 return success();
1389
1390 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1391 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1392 if (expectedNumArguments != actualNumArguments)
1393 return launchOp.emitOpError("got ")
1394 << actualNumArguments << " kernel operands but expected "
1395 << expectedNumArguments;
1396
1397 FunctionType functionType = kernelGPUFunction.getFunctionType();
1398 for (unsigned i = 0; i < expectedNumArguments; ++i) {
1399 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1400 return launchOp.emitOpError("type of function argument ")
1401 << i << " does not match";
1402 }
1403 }
1404
1405 return success();
1406}
1407
1408static ParseResult
1410 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1411 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1412 if (succeeded(parser.parseOptionalColon())) {
1413 if (parser.parseType(dimTy))
1414 return failure();
1415 } else {
1416 dimTy = IndexType::get(parser.getContext());
1417 }
1418 if (clusterValue.has_value()) {
1419 clusterXTy = clusterYTy = clusterZTy = dimTy;
1420 }
1421 return success();
1422}
1423
1424static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1425 Value clusterValue, Type clusterXTy,
1426 Type clusterYTy, Type clusterZTy) {
1427 if (!dimTy.isIndex())
1428 printer << ": " << dimTy;
1429}
1430
1431static ParseResult parseLaunchFuncOperands(
1432 OpAsmParser &parser,
1434 SmallVectorImpl<Type> &argTypes) {
1435 if (parser.parseOptionalKeyword("args"))
1436 return success();
1437
1438 auto parseElement = [&]() -> ParseResult {
1439 return failure(parser.parseOperand(argNames.emplace_back()) ||
1440 parser.parseColonType(argTypes.emplace_back()));
1441 };
1442
1444 parseElement, " in argument list");
1445}
1446
1448 OperandRange operands, TypeRange types) {
1449 if (operands.empty())
1450 return;
1451 printer << "args(";
1452 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1453 [&](const auto &pair) {
1454 auto [operand, type] = pair;
1455 printer << operand << " : " << type;
1456 });
1457 printer << ")";
1458}
1459
1460//===----------------------------------------------------------------------===//
1461// ShuffleOp
1462//===----------------------------------------------------------------------===//
1463
1464void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1465 int32_t offset, int32_t width, ShuffleMode mode) {
1466 build(builder, result, value,
1467 arith::ConstantOp::create(builder, result.location,
1468 builder.getI32IntegerAttr(offset)),
1469 arith::ConstantOp::create(builder, result.location,
1470 builder.getI32IntegerAttr(width)),
1471 mode);
1472}
1473
1474//===----------------------------------------------------------------------===//
1475// RotateOp
1476//===----------------------------------------------------------------------===//
1477
1478LogicalResult RotateOp::verify() {
1479 uint32_t offset = getOffset();
1480 uint32_t width = getWidth();
1481
1482 if (offset >= width) {
1483 return emitOpError() << "offset must be in the range [0, " << width << ")";
1484 }
1485
1486 return success();
1487}
1488
1489//===----------------------------------------------------------------------===//
1490// BarrierOp
1491//===----------------------------------------------------------------------===//
1492
1493/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1494static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1495 PatternRewriter &rewriter) {
1496 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1497 if (!nextOp)
1498 return failure();
1499
1500 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1501 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1502
1503 if (thisMemfence) {
1504 rewriter.modifyOpInPlace(op, [&]() {
1505 if (!nextMemfence) {
1506 op.removeAddressSpacesAttr();
1507 return;
1508 }
1509 // Fast path - merge where the two barriers fence the same spaces.
1510 if (*thisMemfence == *nextMemfence) {
1511 return;
1512 }
1513
1514 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1515 for (Attribute attr : *thisMemfence)
1516 mergedSpaces.insert(attr);
1517 for (Attribute attr : *nextMemfence)
1518 mergedSpaces.insert(attr);
1519 op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
1520 });
1521 }
1522
1523 rewriter.eraseOp(nextOp);
1524 return success();
1525}
1526
1527void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1528 MLIRContext *context) {
1530}
1531
1532void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1533 mlir::OperationState &odsState,
1534 std::optional<AddressSpace> addressSpace) {
1535 ArrayAttr addressSpacesAttr;
1536 if (addressSpace)
1537 addressSpacesAttr = odsBuilder.getArrayAttr(
1538 AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
1539 build(odsBuilder, odsState, addressSpacesAttr);
1540}
1541
1542/// Builds a barrier that causes memory operations affecting `memrefToFence` to
1543/// be completed after the barrier is concluded. Currently, this means setting
1544/// the fenced address spaces to those of the given memref if it is a gpu
1545/// address space.
1546void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1547 Value memrefToFence) {
1548 std::optional<AddressSpace> addrSpaceToFence;
1549 if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
1550 if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1551 memrefType.getMemorySpace()))
1552 addrSpaceToFence = addrSpaceAttr.getValue();
1553 return build(builder, odsState, addrSpaceToFence);
1554}
1555
1556//===----------------------------------------------------------------------===//
1557// GPUFuncOp
1558//===----------------------------------------------------------------------===//
1559
1560/// Adds a new block argument that corresponds to buffers located in
1561/// workgroup memory.
1562BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1563 auto attrName = getNumWorkgroupAttributionsAttrName();
1564 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1565 (*this)->setAttr(attrName,
1566 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1567 return getBody().insertArgument(
1568 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1569}
1570
1571 /// Adds a new block argument that corresponds to buffers located in
1572 /// private memory.
/// The argument is appended at the very end of the entry block argument list.
1573 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1574 // Buffers on the private memory always come after buffers on the workgroup
1575 // memory.
1576 return getBody().addArgument(type, loc);
1577 }
1578
1579void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1580 StringRef name, FunctionType type,
1581 TypeRange workgroupAttributions,
1582 TypeRange privateAttributions,
1583 ArrayRef<NamedAttribute> attrs) {
1584 OpBuilder::InsertionGuard g(builder);
1585
1587 builder.getStringAttr(name));
1588 result.addAttribute(getFunctionTypeAttrName(result.name),
1589 TypeAttr::get(type));
1590 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1591 builder.getI64IntegerAttr(workgroupAttributions.size()));
1592 result.addAttributes(attrs);
1593 Region *body = result.addRegion();
1594 Block *entryBlock = builder.createBlock(body);
1595
1596 // TODO: Allow passing in proper locations here.
1597 for (Type argTy : type.getInputs())
1598 entryBlock->addArgument(argTy, result.location);
1599 for (Type argTy : workgroupAttributions)
1600 entryBlock->addArgument(argTy, result.location);
1601 for (Type argTy : privateAttributions)
1602 entryBlock->addArgument(argTy, result.location);
1603}
1604
1605/// Parses a GPU function memory attribution.
1606///
1607/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1608/// (`private` `(` ssa-id-and-type-list `)`)?
1609///
1610/// Note that this function parses only one of the two similar parts, with the
1611/// keyword provided as argument.
1612static ParseResult
1613parseAttributions(OpAsmParser &parser, StringRef keyword,
1615 Attribute &attributionAttrs) {
1616 // If we could not parse the keyword, just assume empty list and succeed.
1617 if (failed(parser.parseOptionalKeyword(keyword)))
1618 return success();
1619
1620 size_t existingArgs = args.size();
1621 ParseResult result =
1623 /*allowType=*/true, /*allowAttrs=*/true);
1624 if (failed(result))
1625 return result;
1626
1627 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1628 [](const OpAsmParser::Argument &arg) -> bool {
1629 return arg.attrs && !arg.attrs.empty();
1630 });
1631 if (!hadAttrs) {
1632 attributionAttrs = nullptr;
1633 return result;
1634 }
1635
1636 Builder &builder = parser.getBuilder();
1637 SmallVector<Attribute> attributionAttrsVec;
1638 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1639 if (!argument.attrs)
1640 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1641 else
1642 attributionAttrsVec.push_back(argument.attrs);
1643 }
1644 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1645 return result;
1646}
1647
1648 /// Parses a GPU function.
1649 ///
1650 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1651 /// (`->` function-result-list)? memory-attribution `kernel`?
1652 /// function-attributes? region
1653 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1654 SmallVector<OpAsmParser::Argument> entryArgs;
1655 SmallVector<DictionaryAttr> resultAttrs;
1656 SmallVector<Type> resultTypes;
1657 bool isVariadic;
1658
1659 // Parse the function name.
1660 StringAttr nameAttr;
// NOTE(review): the extraction dropped the line starting this call —
// presumably `if (parser.parseSymbolName(nameAttr, ..., ` — confirm against
// upstream GPUDialect.cpp.
1662 result.attributes))
1663 return failure();
1664
1665 auto signatureLocation = parser.getCurrentLocation();
// NOTE(review): the extraction dropped the head of the signature-parsing
// call (a function_interface_impl helper taking the arguments below) —
// confirm against upstream GPUDialect.cpp.
1667 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1668 resultAttrs)))
1669 return failure();
1670
1671 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1672 return parser.emitError(signatureLocation)
1673 << "gpu.func requires named arguments";
1674
1675 // Construct the function type. More types will be added to the region, but
1676 // not to the function type.
1677 Builder &builder = parser.getBuilder();
1678
1679 SmallVector<Type> argTypes;
1680 for (auto &arg : entryArgs)
1681 argTypes.push_back(arg.type);
1682 auto type = builder.getFunctionType(argTypes, resultTypes);
1683 result.addAttribute(getFunctionTypeAttrName(result.name),
1684 TypeAttr::get(type));
1685
// NOTE(review): the extraction dropped the head of this call — presumably
// the addArgAndResultAttrs interface helper — confirm against upstream.
1687 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1688 getResAttrsAttrName(result.name));
1689
1690 Attribute workgroupAttributionAttrs;
1691 // Parse workgroup memory attributions.
1692 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1693 entryArgs, workgroupAttributionAttrs)))
1694 return failure();
1695
1696 // Store the number of operands we just parsed as the number of workgroup
1697 // memory attributions.
1698 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1699 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1700 builder.getI64IntegerAttr(numWorkgroupAttrs));
1701 if (workgroupAttributionAttrs)
1702 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1703 workgroupAttributionAttrs);
1704
1705 Attribute privateAttributionAttrs;
1706 // Parse private memory attributions.
1707 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1708 entryArgs, privateAttributionAttrs)))
1709 return failure();
1710 if (privateAttributionAttrs)
1711 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1712 privateAttributionAttrs);
1713
1714 // Parse the kernel attribute if present.
1715 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1716 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1717 builder.getUnitAttr());
1718
1719 // Parse attributes.
1720 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1721 return failure();
1722
1723 // Parse the region. If no argument names were provided, take all names
1724 // (including those of attributions) from the entry block.
1725 auto *body = result.addRegion();
1726 return parser.parseRegion(*body, entryArgs);
1727 }
1728
1729void GPUFuncOp::print(OpAsmPrinter &p) {
1730 p << ' ';
1731 p.printSymbolName(getName());
1732
1733 FunctionType type = getFunctionType();
1734 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1735 /*isVariadic=*/false,
1736 type.getResults());
1737
1738 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1739 getWorkgroupAttribAttrs().value_or(nullptr));
1740 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1741 getPrivateAttribAttrs().value_or(nullptr));
1742 if (isKernel())
1743 p << ' ' << getKernelKeyword();
1744
1746 p, *this,
1747 {getNumWorkgroupAttributionsAttrName(),
1748 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1749 getArgAttrsAttrName(), getResAttrsAttrName(),
1750 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1751 p << ' ';
1752 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1753}
1754
1755static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1756 StringAttr attrName) {
1757 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1758 if (!allAttrs || index >= allAttrs.size())
1759 return DictionaryAttr();
1760 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1761}
1762
/// Returns the attribute dictionary of the `index`-th workgroup attribution.
/// NOTE(review): lowercase `w` in the method name is inconsistent with
/// getPrivateAttributionAttrs but is part of the public interface.
1763 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1764 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1765 }
1766
/// Returns the attribute dictionary of the `index`-th private attribution.
1767 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1768 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1769 }
1770
1771static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1772 DictionaryAttr value, StringAttr attrName) {
1773 MLIRContext *ctx = op.getContext();
1774 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1775 SmallVector<Attribute> elements;
1776 if (allAttrs)
1777 elements.append(allAttrs.begin(), allAttrs.end());
1778 while (elements.size() <= index)
1779 elements.push_back(DictionaryAttr::get(ctx));
1780 if (!value)
1781 elements[index] = DictionaryAttr::get(ctx);
1782 else
1783 elements[index] = value;
1784 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1785 op->setAttr(attrName, newValue);
1786}
1787
/// Sets the attribute dictionary of the workgroup attribution at `index`;
/// a null `value` clears it to an empty dictionary.
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}
1792
/// Sets the attribute dictionary of the private attribution at `index`;
/// a null `value` clears it to an empty dictionary.
void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
1797
1798static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1799 StringAttr name, StringAttr attrsName) {
1800 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1801 if (!dict)
1802 return Attribute();
1803 return dict.get(name);
1804}
1805
/// Returns attribute `name` of the workgroup attribution at `index`, or null.
Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}
1813
/// Returns attribute `name` of the private attribution at `index`, or null.
Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
1821
1822static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1823 Attribute value, StringAttr attrsName) {
1824 MLIRContext *ctx = op.getContext();
1826 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1827 if (oldDict)
1828 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1829
1830 bool found = false;
1831 bool mustSort = true;
1832 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1833 if (elems[i].getName() == name) {
1834 found = true;
1835 if (!value) {
1836 std::swap(elems[i], elems[elems.size() - 1]);
1837 elems.pop_back();
1838 } else {
1839 mustSort = false;
1840 elems[i] = NamedAttribute(elems[i].getName(), value);
1841 }
1842 break;
1843 }
1844 }
1845 if (!found) {
1846 if (!value)
1847 return;
1848 elems.emplace_back(name, value);
1849 }
1850 if (mustSort) {
1851 DictionaryAttr::sortInPlace(elems);
1852 }
1853 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1854 setAttributionAttrs(op, index, newDict, attrsName);
1855}
1856
/// Sets (or removes, when `value` is null) attribute `name` on the workgroup
/// attribution at `index`.
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}
1864
/// Sets (or removes, when `value` is null) attribute `name` on the private
/// attribution at `index`.
void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
1872
1873LogicalResult GPUFuncOp::verifyType() {
1874 if (isKernel() && getFunctionType().getNumResults() != 0)
1875 return emitOpError() << "expected void return type for kernel function";
1876
1877 return success();
1878}
1879
1880/// Verifies the body of the function.
1881LogicalResult GPUFuncOp::verifyBody() {
1882 if (empty())
1883 return emitOpError() << "expected body with at least one block";
1884 unsigned numFuncArguments = getNumArguments();
1885 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1886 unsigned numBlockArguments = front().getNumArguments();
1887 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1888 return emitOpError() << "expected at least "
1889 << numFuncArguments + numWorkgroupAttributions
1890 << " arguments to body region";
1891
1892 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1893 for (unsigned i = 0; i < numFuncArguments; ++i) {
1894 Type blockArgType = front().getArgument(i).getType();
1895 if (funcArgTypes[i] != blockArgType)
1896 return emitOpError() << "expected body region argument #" << i
1897 << " to be of type " << funcArgTypes[i] << ", got "
1898 << blockArgType;
1899 }
1900
1901 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1902 GPUDialect::getWorkgroupAddressSpace())) ||
1903 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1904 GPUDialect::getPrivateAddressSpace())))
1905 return failure();
1906
1907 return success();
1908}
1909
1910//===----------------------------------------------------------------------===//
1911// ReturnOp
1912//===----------------------------------------------------------------------===//
1913
1914LogicalResult gpu::ReturnOp::verify() {
1915 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1916
1917 FunctionType funType = function.getFunctionType();
1918
1919 if (funType.getNumResults() != getOperands().size())
1920 return emitOpError()
1921 .append("expected ", funType.getNumResults(), " result operands")
1922 .attachNote(function.getLoc())
1923 .append("return type declared here");
1924
1925 for (const auto &pair : llvm::enumerate(
1926 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1927 auto [type, operand] = pair.value();
1928 if (type != operand.getType())
1929 return emitOpError() << "unexpected type `" << operand.getType()
1930 << "' for operand #" << pair.index();
1931 }
1932 return success();
1933}
1934
1935//===----------------------------------------------------------------------===//
1936// GPUModuleOp
1937//===----------------------------------------------------------------------===//
1938
/// Builds a gpu.module named `name` with an optional `targets` array and
/// optional offloading handler, creating an empty body block.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  // The module body starts out as a single empty block.
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}
1949
1950void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1951 StringRef name, ArrayRef<Attribute> targets,
1952 Attribute offloadingHandler) {
1953 build(builder, result, name,
1954 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1955 offloadingHandler);
1956}
1957
1958bool GPUModuleOp::hasTarget(Attribute target) {
1959 if (ArrayAttr targets = getTargetsAttr())
1960 return llvm::count(targets.getValue(), target);
1961 return false;
1962}
1963
1964void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1965 ArrayAttr &targetsAttr = getProperties().targets;
1966 SmallVector<Attribute> targetsVector(targets);
1967 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1968}
1969
1970LogicalResult GPUModuleOp::verify() {
1971 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1972
1973 if (!targets)
1974 return success();
1975
1976 for (auto target : targets) {
1977 if (auto verifyTargetAttr =
1978 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1979 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1980 return failure();
1981 }
1982 }
1983 return success();
1984}
1985
1986//===----------------------------------------------------------------------===//
1987// GPUBinaryOp
1988//===----------------------------------------------------------------------===//
1989void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1990 Attribute offloadingHandler, ArrayAttr objects) {
1991 auto &properties = result.getOrAddProperties<Properties>();
1992 result.attributes.push_back(builder.getNamedAttr(
1994 properties.objects = objects;
1995 if (offloadingHandler)
1996 properties.offloadingHandler = offloadingHandler;
1997 else
1998 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1999}
2000
2001void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2002 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2003 build(builder, result, name, offloadingHandler,
2004 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
2005}
2006
2007static ParseResult parseOffloadingHandler(OpAsmParser &parser,
2008 Attribute &offloadingHandler) {
2009 if (succeeded(parser.parseOptionalLess())) {
2010 if (parser.parseAttribute(offloadingHandler))
2011 return failure();
2012 if (parser.parseGreater())
2013 return failure();
2014 }
2015 if (!offloadingHandler)
2016 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2017 return success();
2018}
2019
2021 Attribute offloadingHandler) {
2022 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2023 printer << '<' << offloadingHandler << '>';
2024}
2025
2026//===----------------------------------------------------------------------===//
2027// GPUMemcpyOp
2028//===----------------------------------------------------------------------===//
2029
2030LogicalResult MemcpyOp::verify() {
2031 auto srcType = getSrc().getType();
2032 auto dstType = getDst().getType();
2033
2034 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2035 return emitOpError("arguments have incompatible element type");
2036
2037 if (failed(verifyCompatibleShape(srcType, dstType)))
2038 return emitOpError("arguments have incompatible shape");
2039
2040 return success();
2041}
2042
2043namespace {
2044
2045/// Erases a common case of copy ops where a destination value is used only by
2046/// the copy op, alloc and dealloc ops.
2047struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2048 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2049
2050 LogicalResult matchAndRewrite(MemcpyOp op,
2051 PatternRewriter &rewriter) const override {
2052 Value dest = op.getDst();
2053 Operation *destDefOp = dest.getDefiningOp();
2054 // `dest` must be defined by an op having Allocate memory effect in order to
2055 // perform the folding.
2056 if (!destDefOp ||
2058 return failure();
2059 // We can erase `op` iff `dest` has no other use apart from its
2060 // use by `op` and dealloc ops.
2061 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2062 return user != op &&
2063 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2064 }))
2065 return failure();
2066 // We can perform the folding if and only if op has a single async
2067 // dependency and produces an async token as result, or if it does not have
2068 // any async dependency and does not produce any async token result.
2069 if (op.getAsyncDependencies().size() > 1 ||
2070 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2071 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2072 return failure();
2073 rewriter.replaceOp(op, op.getAsyncDependencies());
2074 return success();
2075 }
2076};
2077
2078} // end anonymous namespace
2079
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Drop copies whose destination is only allocated, copied into, and freed.
  results.add<EraseTrivialCopyOp>(context);
}
2084
2085//===----------------------------------------------------------------------===//
2086// GPU_SubgroupMmaLoadMatrixOp
2087//===----------------------------------------------------------------------===//
2088
2089LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2090 auto srcType = getSrcMemref().getType();
2091 auto resType = getRes().getType();
2092 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2093 auto operand = resMatrixType.getOperand();
2094 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2095
2096 if (!srcMemrefType.isLastDimUnitStride())
2097 return emitError(
2098 "expected source memref most minor dim must have unit stride");
2099
2100 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2101 return emitError("only AOp, BOp and COp can be loaded");
2102
2103 return success();
2104}
2105
2106//===----------------------------------------------------------------------===//
2107// GPU_SubgroupMmaStoreMatrixOp
2108//===----------------------------------------------------------------------===//
2109
2110LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2111 auto srcType = getSrc().getType();
2112 auto dstType = getDstMemref().getType();
2113 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2114 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2115
2116 if (!dstMemrefType.isLastDimUnitStride())
2117 return emitError(
2118 "expected destination memref most minor dim must have unit stride");
2119
2120 if (srcMatrixType.getOperand() != "COp")
2121 return emitError(
2122 "expected the operand matrix being stored to have 'COp' operand type");
2123
2124 return success();
2125}
2126
2127//===----------------------------------------------------------------------===//
2128// GPU_SubgroupMmaComputeOp
2129//===----------------------------------------------------------------------===//
2130
2131LogicalResult SubgroupMmaComputeOp::verify() {
2132 enum OperandMap { A, B, C };
2133 SmallVector<MMAMatrixType, 3> opTypes;
2134 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2135 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2136 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2137
2138 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2139 opTypes[C].getOperand() != "COp")
2140 return emitError("operands must be in the order AOp, BOp, COp");
2141
2142 ArrayRef<int64_t> aShape, bShape, cShape;
2143 aShape = opTypes[A].getShape();
2144 bShape = opTypes[B].getShape();
2145 cShape = opTypes[C].getShape();
2146
2147 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2148 bShape[1] != cShape[1])
2149 return emitError("operand shapes do not satisfy matmul constraints");
2150
2151 return success();
2152}
2153
/// Folds memref.cast-fed operands into the memcpy itself.
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2158
/// Folds memref.cast-fed operands into the memset itself.
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2163
2164//===----------------------------------------------------------------------===//
2165// GPU_WaitOp
2166//===----------------------------------------------------------------------===//
2167
2168namespace {
2169
2170/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2171/// %t = gpu.wait async [] // No async dependencies.
2172/// ... gpu.wait ... [%t, ...] // %t can be removed.
2173struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2174public:
2176
2177 LogicalResult matchAndRewrite(WaitOp op,
2178 PatternRewriter &rewriter) const final {
2179 auto predicate = [](Value value) {
2180 auto waitOp = value.getDefiningOp<WaitOp>();
2181 return waitOp && waitOp->getNumOperands() == 0;
2182 };
2183 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2184 return failure();
2185 SmallVector<Value> validOperands;
2186 for (Value operand : op->getOperands()) {
2187 if (predicate(operand))
2188 continue;
2189 validOperands.push_back(operand);
2190 }
2191 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2192 return success();
2193 }
2194};
2195
2196/// Simplify trivial gpu.wait ops for the following patterns.
2197/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2198/// dependencies).
2199/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2200/// %t0.
2201/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2202/// dependencies nor return any token.
2203struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2204public:
2206
2207 LogicalResult matchAndRewrite(WaitOp op,
2208 PatternRewriter &rewriter) const final {
2209 // Erase gpu.wait ops that neither have any async dependencies nor return
2210 // any async token.
2211 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2212 rewriter.eraseOp(op);
2213 return success();
2214 }
2215 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2216 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2217 op.getAsyncToken()) {
2218 rewriter.replaceOp(op, op.getAsyncDependencies());
2219 return success();
2220 }
2221 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2222 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2223 rewriter.eraseOp(op);
2224 return success();
2225 }
2226 return failure();
2227 }
2228};
2229
2230} // end anonymous namespace
2231
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Remove redundant wait-token chains and trivially dead waits.
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
2236
2237//===----------------------------------------------------------------------===//
2238// GPU_AllocOp
2239//===----------------------------------------------------------------------===//
2240
2241LogicalResult AllocOp::verify() {
2242 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2243
2244 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2245 getDynamicSizes())))
2246 return failure();
2247
2248 unsigned numSymbols = 0;
2249 if (!memRefType.getLayout().isIdentity())
2250 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2251 if (getSymbolOperands().size() != numSymbols) {
2252 return emitOpError(
2253 "symbol operand count does not equal memref symbol count");
2254 }
2255
2256 return success();
2257}
2258
2259namespace {
2260
2261/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2262/// `memref::AllocOp`.
2263struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2264 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2265
2266 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2267 PatternRewriter &rewriter) const override {
2268 std::optional<int64_t> index = dimOp.getConstantIndex();
2269 if (!index)
2270 return failure();
2271
2272 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2273 if (!memrefType || index.value() >= memrefType.getRank() ||
2274 !memrefType.isDynamicDim(index.value()))
2275 return failure();
2276
2277 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2278 if (!alloc)
2279 return failure();
2280
2281 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2282 memrefType.getDynamicDimIndex(index.value()));
2283 rewriter.replaceOp(dimOp, substituteOp);
2284 return success();
2285 }
2286};
2287
2288} // namespace
2289
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Fold memref.dim of a gpu.alloc result to the alloc's size operand.
  results.add<SimplifyDimOfAllocOp>(context);
}
2294
2295//===----------------------------------------------------------------------===//
2296// GPU object attribute
2297//===----------------------------------------------------------------------===//
2298
2299LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2300 Attribute target, CompilationTarget format,
2301 StringAttr object, DictionaryAttr properties,
2302 KernelTableAttr kernels) {
2303 if (!target)
2304 return emitError() << "the target attribute cannot be null";
2305 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2306 return success();
2307 return emitError() << "the target attribute must implement or promise the "
2308 "`gpu::TargetAttrInterface`";
2309}
2310
namespace {
/// Parses the `[format =] object` portion of a GPU ObjectAttr; the format
/// keyword is optional and defaults to `fatbin`.
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  // No keyword at all selects the default fatbin format.
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  // A recognized format keyword must be followed by `=`.
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  // Reaching here without a format means the keyword was not a valid enum.
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

/// Prints the object string, prefixed by its format unless it is the default
/// `fatbin`.
void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace
2344
2345//===----------------------------------------------------------------------===//
2346// GPU select object attribute
2347//===----------------------------------------------------------------------===//
2348
2349LogicalResult
2350gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2351 Attribute target) {
2352 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2353 if (target) {
2354 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2355 if (intAttr.getInt() < 0) {
2356 return emitError() << "the object index must be positive";
2357 }
2358 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2359 return emitError()
2360 << "the target attribute must be a GPU Target attribute";
2361 }
2362 }
2363 return success();
2364}
2365
2366//===----------------------------------------------------------------------===//
2367// DynamicSharedMemoryOp
2368//===----------------------------------------------------------------------===//
2369
2370LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2371 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2372 return emitOpError() << "must be inside an op with symbol table";
2373
2374 MemRefType memrefType = getResultMemref().getType();
2375 // Check address space
2376 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2377 return emitOpError() << "address space must be "
2378 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2379 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2380 }
2381 if (memrefType.hasStaticShape()) {
2382 return emitOpError() << "result memref type must be memref<?xi8, "
2383 "#gpu.address_space<workgroup>>";
2384 }
2385 return success();
2386}
2387
2388//===----------------------------------------------------------------------===//
2389// GPU WarpExecuteOnLane0Op
2390//===----------------------------------------------------------------------===//
2391
/// Prints: `(%laneid)[warp-size] args(...)? (-> types)? region attr-dict?`
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  // The warp size is printed in brackets and elided from the trailing
  // attribute dictionary below.
  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  // Terminators are only printed when the op returns results.
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
2409
/// Parses: `(%laneid)[warp-size] args(...)? (-> types)? region attr-dict?`
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  // Parse the warp size enclosed in square brackets.
  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  // The lane id is always an index value.
  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  // Parse the optional `args(...)` clause forwarding values into the region.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  // Insert the implicit terminator when the region was left without one.
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
2466
2467void WarpExecuteOnLane0Op::getSuccessorRegions(
2468 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2469 if (!point.isParent()) {
2470 regions.push_back(RegionSuccessor::parent());
2471 return;
2472 }
2473
2474 // The warp region is always executed
2475 regions.push_back(RegionSuccessor(&getWarpRegion()));
2476}
2477
2478ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2479 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2480}
/// Builds a warp_execute_on_lane_0 op without forwarded arguments.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/{}, /*argTypes=*/{});
}
2487
/// Builds a warp_execute_on_lane_0 op forwarding `args` into the region as
/// block arguments of types `blockArgTypes`.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  // NOTE(review): assumes the warp-size attribute is the first entry of
  // getAttributeNames() — consistent with getWarpSizeAttrName() used by
  // parse/print; confirm against the generated op declaration.
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  // Create the single region block with one argument per forwarded value.
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
2504
2505/// Helper check if the distributed vector type is consistent with the expanded
2506/// type and distributed size.
2507static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2508 int64_t warpSize, Operation *op) {
2509 // If the types matches there is no distribution.
2510 if (expanded == distributed)
2511 return success();
2512 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2513 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2514 if (!expandedVecType || !distributedVecType)
2515 return op->emitOpError("expected vector type for distributed operands.");
2516 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2517 expandedVecType.getElementType() != distributedVecType.getElementType())
2518 return op->emitOpError(
2519 "expected distributed vectors to have same rank and element type.");
2520
2521 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2522 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2523 int64_t eDim = expandedVecType.getDimSize(i);
2524 int64_t dDim = distributedVecType.getDimSize(i);
2525 if (eDim == dDim)
2526 continue;
2527 if (eDim % dDim != 0)
2528 return op->emitOpError()
2529 << "expected expanded vector dimension #" << i << " (" << eDim
2530 << ") to be a multipler of the distributed vector dimension ("
2531 << dDim << ")";
2532 scales[i] = eDim / dDim;
2533 }
2534 if (llvm::product_of(scales) != warpSize)
2535 return op->emitOpError()
2536 << "incompatible distribution dimensions from " << expandedVecType
2537 << " to " << distributedVecType << " with warp size = " << warpSize;
2538
2539 return success();
2540}
2541
/// Verifies arg/block-arg and yield/result pairings are consistent under warp
/// distribution.
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
  if (!yield)
    return emitOpError("expected body to be terminated with 'gpu.yield'");
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  // Each forwarded arg type must be a valid distribution of its block arg.
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  // Each result type must be a valid distribution of its yielded value.
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}
2567bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2568 return succeeded(
2569 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2570}
2571
/// Returns the gpu.yield terminator of the warp region.
gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
2575
2576//===----------------------------------------------------------------------===//
2577// GPU_SubgroupBroadcastOp
2578//===----------------------------------------------------------------------===//
2579
/// The result range equals the source range: broadcasting moves a value
/// between lanes without changing it.
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}
2584
2585Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2586 switch (getBroadcastType()) {
2587 case BroadcastType::first_active_lane:
2588 // Cannot speculate first_lane broadcast, because speculating it across
2589 // control flow can change the active lanes.
2591 case BroadcastType::specific_lane:
2592 // Speculation should be safe as long as we inside structured control flow.
2594 }
2595 llvm_unreachable("Unknown BroadcastType");
2596}
2597
2598LogicalResult gpu::SubgroupBroadcastOp::verify() {
2599 switch (getBroadcastType()) {
2600 case BroadcastType::first_active_lane:
2601 if (getLane())
2602 return emitOpError()
2603 << "lane can only be specified for `specific_lane` broadcast";
2604 return success();
2605 case BroadcastType::specific_lane:
2606 if (!getLane())
2607 return emitOpError()
2608 << "lane must be specified for `specific_lane` broadcast";
2609 return success();
2610 }
2611 llvm_unreachable("Unknown BroadcastType");
2612}
2613
/// Folds a broadcast whose source is itself a broadcast result to that
/// result: the value is already uniform, so re-broadcasting is a no-op.
OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // Broadcast result is always uniform.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();

  return nullptr;
}
2621
2622//===----------------------------------------------------------------------===//
2623// GPU_BallotOp
2624//===----------------------------------------------------------------------===//
2625
2626// No custom implementations needed; ballot uses default behavior from ODS.
2627
2628//===----------------------------------------------------------------------===//
2629// GPU KernelMetadataAttr
2630//===----------------------------------------------------------------------===//
2631
2632KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2633 DictionaryAttr metadata) {
2634 assert(kernel && "invalid kernel");
2635 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2636 kernel.getAllArgAttrs(), metadata);
2637}
2638
2639KernelMetadataAttr
2640KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2641 FunctionOpInterface kernel,
2642 DictionaryAttr metadata) {
2643 assert(kernel && "invalid kernel");
2644 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2645 kernel.getAllArgAttrs(), metadata);
2646}
2647
2648KernelMetadataAttr
2649KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2650 if (attrs.empty())
2651 return *this;
2652 NamedAttrList attrList;
2653 if (DictionaryAttr dict = getMetadata())
2654 attrList.append(dict);
2655 attrList.append(attrs);
2656 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2657 attrList.getDictionary(getContext()));
2658}
2659
2660LogicalResult
2661KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2662 StringAttr name, Type functionType,
2663 ArrayAttr argAttrs, DictionaryAttr metadata) {
2664 if (name.empty())
2665 return emitError() << "the kernel name can't be empty";
2666 if (argAttrs) {
2667 if (llvm::any_of(argAttrs, [](Attribute attr) {
2668 return !llvm::isa<DictionaryAttr>(attr);
2669 }))
2670 return emitError()
2671 << "all attributes in the array must be a dictionary attribute";
2672 }
2673 return success();
2674}
2675
2676//===----------------------------------------------------------------------===//
2677// GPU KernelTableAttr
2678//===----------------------------------------------------------------------===//
2679
2680KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2681 ArrayRef<KernelMetadataAttr> kernels,
2682 bool isSorted) {
2683 // Note that `is_sorted` is always only invoked once even with assertions ON.
2684 assert((!isSorted || llvm::is_sorted(kernels)) &&
2685 "expected a sorted kernel array");
2686 // Immediately return the attribute if the array is sorted.
2687 if (isSorted || llvm::is_sorted(kernels))
2688 return Base::get(context, kernels);
2689 // Sort the array.
2690 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2691 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2692 return Base::get(context, kernelsTmp);
2693}
2694
2695KernelTableAttr KernelTableAttr::getChecked(
2696 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2697 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2698 // Note that `is_sorted` is always only invoked once even with assertions ON.
2699 assert((!isSorted || llvm::is_sorted(kernels)) &&
2700 "expected a sorted kernel array");
2701 // Immediately return the attribute if the array is sorted.
2702 if (isSorted || llvm::is_sorted(kernels))
2703 return Base::getChecked(emitError, context, kernels);
2704 // Sort the array.
2705 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2706 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2707 return Base::getChecked(emitError, context, kernelsTmp);
2708}
2709
2710LogicalResult
2711KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2712 ArrayRef<KernelMetadataAttr> kernels) {
2713 if (kernels.size() < 2)
2714 return success();
2715 // Check that the kernels are uniquely named.
2716 if (std::adjacent_find(kernels.begin(), kernels.end(),
2717 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2718 return l.getName() == r.getName();
2719 }) != kernels.end()) {
2720 return emitError() << "expected all kernels to be uniquely named";
2721 }
2722 return success();
2723}
2724
2725KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2726 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2727 return found ? *iterator : KernelMetadataAttr();
2728}
2729
2730KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2731 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2732 return found ? *iterator : KernelMetadataAttr();
2733}
2734
2735//===----------------------------------------------------------------------===//
2736// GPU target options
2737//===----------------------------------------------------------------------===//
2738
2753
2771
2772TypeID TargetOptions::getTypeID() const { return typeID; }
2773
2774StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2775
2779
2780StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2781
2782StringRef TargetOptions::getELFSection() const { return elfSection; }
2783
2787
2788function_ref<void(llvm::Module &)>
2792
2793function_ref<void(llvm::Module &)>
2797
2798function_ref<void(llvm::Module &)>
2802
2804 return isaCallback;
2805}
2806
2807CompilationTarget TargetOptions::getCompilationTarget() const {
2808 return compilationTarget;
2809}
2810
2812 return CompilationTarget::Fatbin;
2813}
2814
2815std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2817 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2818 llvm::StringSaver stringSaver(options.first);
2819 StringRef opts = cmdOptions;
2820 // For a correct tokenization of the command line options `opts` must be
2821 // unquoted, otherwise the tokenization function returns a single string: the
2822 // unquoted `cmdOptions` -which is not the desired behavior.
2823 // Remove any quotes if they are at the beginning and end of the string:
2824 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2825 opts.consume_front("\""), opts.consume_back("\"");
2826 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2827 opts.consume_front("'"), opts.consume_back("'");
2828#ifdef _WIN32
2829 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2830 /*MarkEOLs=*/false);
2831#else
2832 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2833 /*MarkEOLs=*/false);
2834#endif // _WIN32
2835 return options;
2836}
2837
2838std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2842
2843std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2845 size_t startPos = cmdOptions.find(startsWith);
2846 if (startPos == std::string::npos)
2847 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2848
2849 auto tokenized =
2850 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2851 cmdOptions.resize(startPos);
2852 return tokenized;
2853}
2854
2856
2857#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2858#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2859
2860#define GET_ATTRDEF_CLASSES
2861#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2862
2863#define GET_OP_CLASSES
2864#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2865
2866#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39