MLIR 23.0.0git
GPUDialect.cpp
Go to the documentation of this file.
1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
30#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
43#include <cassert>
44#include <numeric>
45#include <optional>
46
47using namespace mlir;
48using namespace mlir::gpu;
49
50#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
51
52//===----------------------------------------------------------------------===//
53// GPU Device Mapping Attributes
54//===----------------------------------------------------------------------===//
55
56int64_t GPUBlockMappingAttr::getMappingId() const {
57 return static_cast<int64_t>(getBlock());
58}
59
60bool GPUBlockMappingAttr::isLinearMapping() const {
61 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
62}
63
64int64_t GPUBlockMappingAttr::getRelativeIndex() const {
65 return isLinearMapping()
66 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
67 : getMappingId();
68}
69
70int64_t GPUWarpgroupMappingAttr::getMappingId() const {
71 return static_cast<int64_t>(getWarpgroup());
72}
73
74bool GPUWarpgroupMappingAttr::isLinearMapping() const {
75 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
76}
77
78int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
79 return isLinearMapping()
80 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
81 : getMappingId();
82}
83
84int64_t GPUWarpMappingAttr::getMappingId() const {
85 return static_cast<int64_t>(getWarp());
86}
87
88bool GPUWarpMappingAttr::isLinearMapping() const {
89 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
90}
91
92int64_t GPUWarpMappingAttr::getRelativeIndex() const {
93 return isLinearMapping()
94 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
95 : getMappingId();
96}
97
98int64_t GPUThreadMappingAttr::getMappingId() const {
99 return static_cast<int64_t>(getThread());
100}
101
102bool GPUThreadMappingAttr::isLinearMapping() const {
103 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
104}
105
106int64_t GPUThreadMappingAttr::getRelativeIndex() const {
107 return isLinearMapping()
108 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
109 : getMappingId();
110}
111
112int64_t GPULaneMappingAttr::getMappingId() const {
113 return static_cast<int64_t>(getLane());
114}
115
116bool GPULaneMappingAttr::isLinearMapping() const {
117 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
118}
119
120int64_t GPULaneMappingAttr::getRelativeIndex() const {
121 return isLinearMapping()
122 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
123 : getMappingId();
124}
125
126int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
127
128/// 8 4 0
129/// Example mask : 0 0 0 1 1 0 1 0 0
130///
131/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
132/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
133///
134/// Example mask : 0 0 0 1 1 0 1 0 0
135/// Example filter: 0 0 0 0 1 1 1 1 1
136/// Intersection : 0 0 0 0 1 0 1 0 0
137/// PopCnt : 2
138Value GPUMappingMaskAttr::createLogicalLinearMappingId(
139 OpBuilder &b, Value physicalLinearMappingId) const {
140 Location loc = physicalLinearMappingId.getLoc();
141 Value mask =
142 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
143 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
144 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
145 filter = arith::SubIOp::create(b, loc, filter, one);
146 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
147 return math::CtPopOp::create(b, loc, filteredId);
148}
149
150/// 8 4 0
151/// Example mask : 0 0 0 1 1 0 1 0 0
152///
153/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
154/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
155///
156/// Example mask : 0 0 0 1 1 0 1 0 0
157/// Example filter: 0 0 0 1 0 0 0 0 0
158/// Intersection : 0 0 0 1 0 0 0 0 0
159/// Cmp : 1
160Value GPUMappingMaskAttr::createIsActiveIdPredicate(
161 OpBuilder &b, Value physicalLinearMappingId) const {
162 Location loc = physicalLinearMappingId.getLoc();
163 Value mask =
164 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
165 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
166 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
167 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
168 Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
169 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
170 zero);
171}
172
173int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
174 return static_cast<int64_t>(getAddressSpace());
175}
176
177bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
178 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
179}
180
181int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
182 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
183}
184
185//===----------------------------------------------------------------------===//
186// MMAMatrixType
187//===----------------------------------------------------------------------===//
188
190 StringRef operand) {
191 return Base::get(elementType.getContext(), shape, elementType, operand);
192}
193
196 ArrayRef<int64_t> shape, Type elementType,
197 StringRef operand) {
198 return Base::getChecked(emitError, elementType.getContext(), shape,
199 elementType, operand);
200}
201
202unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
203
205 return getImpl()->getShape();
206}
207
208Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
209
210StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
211
213 return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
214 elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
215 elementType.isInteger(32);
216}
217
218LogicalResult
220 ArrayRef<int64_t> shape, Type elementType,
221 StringRef operand) {
222 if (operand != "AOp" && operand != "BOp" && operand != "COp")
223 return emitError() << "operand expected to be one of AOp, BOp or COp";
224
225 if (shape.size() != 2)
226 return emitError() << "MMAMatrixType must have exactly two dimensions";
227
228 if (!MMAMatrixType::isValidElementType(elementType))
229 return emitError()
230 << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
231
232 return success();
233}
234
235//===----------------------------------------------------------------------===//
236// GPUDialect
237//===----------------------------------------------------------------------===//
238
239bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
240 if (!memorySpace)
241 return false;
242 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
243 return gpuAttr.getValue() == getWorkgroupAddressSpace();
244 return false;
245}
246
247bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
248 Attribute memorySpace = type.getMemorySpace();
249 return isWorkgroupMemoryAddressSpace(memorySpace);
250}
251
252bool GPUDialect::isConstantMemoryAddressSpace(Attribute memorySpace) {
253 if (!memorySpace)
254 return false;
255 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
256 return gpuAttr.getValue() == getConstantAddressSpace();
257 return false;
258}
259
260bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
261 Attribute memorySpace = type.getMemorySpace();
262 return isConstantMemoryAddressSpace(memorySpace);
263}
264
265bool GPUDialect::isKernel(Operation *op) {
266 if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
267 return gpuFunc.isKernel();
268 return static_cast<bool>(
269 op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
270}
271
272namespace {
273/// This class defines the interface for handling inlining with gpu
274/// operations.
275struct GPUInlinerInterface : public DialectInlinerInterface {
276 using DialectInlinerInterface::DialectInlinerInterface;
277
278 /// All gpu dialect ops can be inlined.
279 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
280 return true;
281 }
282};
283} // namespace
284
285void GPUDialect::initialize() {
286 addTypes<AsyncTokenType>();
287 addTypes<MMAMatrixType>();
288 addTypes<SparseDnTensorHandleType>();
289 addTypes<SparseSpMatHandleType>();
290 addTypes<SparseSpGEMMOpHandleType>();
291 addOperations<
292#define GET_OP_LIST
293#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
294 >();
295 addAttributes<
296#define GET_ATTRDEF_LIST
297#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
298 >();
299 addInterfaces<GPUInlinerInterface>();
300 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
301 TerminatorOp>();
302 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
303 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
304 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
305 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
306 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
307 declarePromisedInterfaces<memref::IndexedAccessOpInterface,
308 SubgroupMmaLoadMatrixOp,
309 SubgroupMmaStoreMatrixOp>();
310}
311
312static std::string getSparseHandleKeyword(SparseHandleKind kind) {
313 switch (kind) {
315 return "sparse.dntensor_handle";
317 return "sparse.spmat_handle";
319 return "sparse.spgemmop_handle";
320 }
321 llvm_unreachable("unknown sparse handle kind");
322 return "";
323}
324
325Type GPUDialect::parseType(DialectAsmParser &parser) const {
326 // Parse the main keyword for the type.
327 StringRef keyword;
328 if (parser.parseKeyword(&keyword))
329 return Type();
330 MLIRContext *context = getContext();
331
332 // Handle 'async token' types.
333 if (keyword == "async.token")
334 return AsyncTokenType::get(context);
335
336 if (keyword == "mma_matrix") {
337 SMLoc beginLoc = parser.getNameLoc();
338
339 // Parse '<'.
340 if (parser.parseLess())
341 return nullptr;
342
343 // Parse the size and elementType.
344 SmallVector<int64_t> shape;
345 Type elementType;
346 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
347 parser.parseType(elementType))
348 return nullptr;
349
350 // Parse ','
351 if (parser.parseComma())
352 return nullptr;
353
354 // Parse operand.
355 std::string operand;
356 if (failed(parser.parseOptionalString(&operand)))
357 return nullptr;
358
359 // Parse '>'.
360 if (parser.parseGreater())
361 return nullptr;
362
364 parser.getEncodedSourceLoc(beginLoc)),
365 shape, elementType, operand);
366 }
367
369 return SparseDnTensorHandleType::get(context);
371 return SparseSpMatHandleType::get(context);
373 return SparseSpGEMMOpHandleType::get(context);
374
375 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
376 return Type();
377}
378// TODO: print refined type here. Notice that should be corresponding to the
379// parser
380void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
381 TypeSwitch<Type>(type)
382 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
383 .Case<SparseDnTensorHandleType>([&](Type) {
385 })
386 .Case<SparseSpMatHandleType>(
388 .Case<SparseSpGEMMOpHandleType>([&](Type) {
390 })
391 .Case([&](MMAMatrixType fragTy) {
392 os << "mma_matrix<";
393 auto shape = fragTy.getShape();
394 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
395 os << *dim << 'x';
396 os << shape.back() << 'x' << fragTy.getElementType();
397 os << ", \"" << fragTy.getOperand() << "\"" << '>';
398 })
399 .DefaultUnreachable("unexpected 'gpu' type kind");
400}
401
402static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
403 NamedAttribute attr) {
404 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
405 if (!array)
406 return op->emitOpError(Twine(attr.getName()) +
407 " must be a dense i32 array");
408 if (array.size() != 3)
409 return op->emitOpError(Twine(attr.getName()) +
410 " must contain exactly 3 elements");
411 return success();
412}
413
414LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
415 NamedAttribute attr) {
416 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
417 return verifyKnownLaunchSizeAttr(op, attr);
418 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
419 return verifyKnownLaunchSizeAttr(op, attr);
420 if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
421 return verifyKnownLaunchSizeAttr(op, attr);
422 if (!llvm::isa<UnitAttr>(attr.getValue()) ||
423 attr.getName() != getContainerModuleAttrName())
424 return success();
425
426 auto module = dyn_cast<ModuleOp>(op);
427 if (!module)
428 return op->emitError("expected '")
429 << getContainerModuleAttrName() << "' attribute to be attached to '"
430 << ModuleOp::getOperationName() << '\'';
431 return success();
432}
433
434/// Parses an optional list of async operands with an optional leading keyword.
435/// (`async`)? (`[` ssa-id-list `]`)?
436///
437/// This method is used by the tablegen assembly format for async ops as well.
438static ParseResult parseAsyncDependencies(
439 OpAsmParser &parser, Type &asyncTokenType,
441 auto loc = parser.getCurrentLocation();
442 if (succeeded(parser.parseOptionalKeyword("async"))) {
443 if (parser.getNumResults() == 0)
444 return parser.emitError(loc, "needs to be named when marked 'async'");
445 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
446 }
447 return parser.parseOperandList(asyncDependencies,
449}
450
451/// Prints optional async dependencies with its leading keyword.
452/// (`async`)? (`[` ssa-id-list `]`)?
453// Used by the tablegen assembly format for several async ops.
455 Type asyncTokenType,
456 OperandRange asyncDependencies) {
457 if (asyncTokenType)
458 printer << "async";
459 if (asyncDependencies.empty())
460 return;
461 if (asyncTokenType)
462 printer << ' ';
463 printer << llvm::interleaved_array(asyncDependencies);
464}
465
466// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
467/// Parses a GPU function memory attribution.
468///
469/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
470/// (`private` `(` ssa-id-and-type-list `)`)?
471///
472/// Note that this function parses only one of the two similar parts, with the
473/// keyword provided as argument.
474static ParseResult
475parseAttributions(OpAsmParser &parser, StringRef keyword,
477 // If we could not parse the keyword, just assume empty list and succeed.
478 if (failed(parser.parseOptionalKeyword(keyword)))
479 return success();
480
482 /*allowType=*/true);
483}
484
485static void printAttributions(OpAsmPrinter &p, StringRef keyword,
487 ArrayAttr attributes = {}) {
488 if (values.empty())
489 return;
490
491 p << ' ' << keyword << '(';
492 llvm::interleaveComma(
493 llvm::enumerate(values), p, [&p, attributes](auto pair) {
494 BlockArgument v = pair.value();
495 p << v << " : " << v.getType();
496
497 size_t attributionIndex = pair.index();
498 DictionaryAttr attrs;
499 if (attributes && attributionIndex < attributes.size())
500 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
501 if (attrs)
502 p.printOptionalAttrDict(attrs.getValue());
503 });
504 p << ')';
505}
506
507/// Verifies a GPU function memory attribution.
508static LogicalResult verifyAttributions(Operation *op,
509 ArrayRef<BlockArgument> attributions,
510 gpu::AddressSpace memorySpace) {
511 for (Value v : attributions) {
512 auto type = llvm::dyn_cast<MemRefType>(v.getType());
513 if (!type)
514 return op->emitOpError() << "expected memref type in attribution";
515
516 // We can only verify the address space if it hasn't already been lowered
517 // from the AddressSpaceAttr to a target-specific numeric value.
518 auto addressSpace =
519 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
520 if (!addressSpace)
521 continue;
522 if (addressSpace.getValue() != memorySpace)
523 return op->emitOpError()
524 << "expected memory space " << stringifyAddressSpace(memorySpace)
525 << " in attribution";
526 }
527 return success();
528}
529
530//===----------------------------------------------------------------------===//
531// AllReduceOp
532//===----------------------------------------------------------------------===//
533
534static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
535 Type resType) {
536 using Kind = gpu::AllReduceOperation;
537 if (llvm::is_contained(
538 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
539 opName)) {
540 if (!isa<FloatType>(resType))
541 return failure();
542 }
543
544 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
545 Kind::AND, Kind::OR, Kind::XOR},
546 opName)) {
547 if (!isa<IntegerType>(resType))
548 return failure();
549 }
550
551 return success();
552}
553
554LogicalResult gpu::AllReduceOp::verifyRegions() {
555 if (getBody().empty() != getOp().has_value())
556 return emitError("expected either an op attribute or a non-empty body");
557 if (!getBody().empty()) {
558 if (getBody().getNumArguments() != 2)
559 return emitError("expected two region arguments");
560 for (auto argument : getBody().getArguments()) {
561 if (argument.getType() != getType())
562 return emitError("incorrect region argument type");
563 }
564 unsigned yieldCount = 0;
565 for (Block &block : getBody()) {
566 if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
567 if (yield.getNumOperands() != 1)
568 return emitError("expected one gpu.yield operand");
569 if (yield.getOperand(0).getType() != getType())
570 return emitError("incorrect gpu.yield type");
571 ++yieldCount;
572 }
573 }
574 if (yieldCount == 0)
575 return emitError("expected gpu.yield op in region");
576 } else {
577 gpu::AllReduceOperation opName = *getOp();
578 if (failed(verifyReduceOpAndType(opName, getType()))) {
579 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
580 << "` reduction operation is not compatible with type "
581 << getType();
582 }
583 }
584
585 return success();
586}
587
589 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
590 if (!launchOp)
591 return false;
592
593 Region &body = launchOp.getBody();
594 assert(!body.empty() && "Invalid region");
595
596 // Only convert ops in gpu::launch entry block for now.
597 return op->getBlock() == &body.front();
598}
599
600OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
601 if (!getUniform() && canMakeGroupOpUniform(*this)) {
602 setUniform(true);
603 return getResult();
604 }
605
606 return nullptr;
607}
608
609// TODO: Support optional custom attributes (without dialect prefix).
610static ParseResult parseAllReduceOperation(AsmParser &parser,
611 AllReduceOperationAttr &attr) {
612 StringRef enumStr;
613 if (!parser.parseOptionalKeyword(&enumStr)) {
614 std::optional<AllReduceOperation> op =
615 gpu::symbolizeAllReduceOperation(enumStr);
616 if (!op)
617 return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
618 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
619 }
620 return success();
621}
622
624 AllReduceOperationAttr attr) {
625 if (attr)
626 attr.print(printer);
627}
628
629//===----------------------------------------------------------------------===//
630// SubgroupReduceOp
631//===----------------------------------------------------------------------===//
632
633LogicalResult gpu::SubgroupReduceOp::verify() {
634 Type elemType = getType();
635 if (auto vecTy = dyn_cast<VectorType>(elemType)) {
636 if (vecTy.isScalable())
637 return emitOpError() << "is not compatible with scalable vector types";
638
639 elemType = vecTy.getElementType();
640 }
641
642 gpu::AllReduceOperation opName = getOp();
643 if (failed(verifyReduceOpAndType(opName, elemType))) {
644 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
645 << "` reduction operation is not compatible with type "
646 << getType();
647 }
648
649 auto clusterSize = getClusterSize();
650 if (clusterSize) {
651 uint32_t size = *clusterSize;
652 if (!llvm::isPowerOf2_32(size)) {
653 return emitOpError() << "cluster size " << size
654 << " is not a power of two";
655 }
656 }
657
658 uint32_t stride = getClusterStride();
659 if (stride != 1 && !clusterSize) {
660 return emitOpError() << "cluster stride can only be specified if cluster "
661 "size is specified";
662 }
663 if (!llvm::isPowerOf2_32(stride)) {
664 return emitOpError() << "cluster stride " << stride
665 << " is not a power of two";
666 }
667
668 return success();
669}
670
671OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
672 if (getClusterSize() == 1)
673 return getValue();
674
675 if (!getUniform() && canMakeGroupOpUniform(*this)) {
676 setUniform(true);
677 return getResult();
678 }
679
680 return nullptr;
681}
682
683//===----------------------------------------------------------------------===//
684// AsyncOpInterface
685//===----------------------------------------------------------------------===//
686
688 op->insertOperands(0, {token});
689 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
690 return;
691 auto attrName =
693 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
694
695 // Async dependencies is the only variadic operand.
696 if (!sizeAttr)
697 return;
698
699 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
700 ++sizes.front();
701 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
702}
703
704//===----------------------------------------------------------------------===//
705// LaunchOp
706//===----------------------------------------------------------------------===//
707
708void LaunchOp::build(OpBuilder &builder, OperationState &result,
709 Value gridSizeX, Value gridSizeY, Value gridSizeZ,
710 Value getBlockSizeX, Value getBlockSizeY,
711 Value getBlockSizeZ, Value dynamicSharedMemorySize,
712 Type asyncTokenType, ValueRange asyncDependencies,
713 TypeRange workgroupAttributions,
714 TypeRange privateAttributions, Value clusterSizeX,
715 Value clusterSizeY, Value clusterSizeZ,
716 FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
717 OpBuilder::InsertionGuard g(builder);
718
719 if (!workgroupAttributions.empty())
720 result.addAttribute(
721 getWorkgroupAttributionsAttrName(result.name),
722 builder.getI64IntegerAttr(workgroupAttributions.size()));
723
724 // Add Op operands.
725 result.addOperands(asyncDependencies);
726 if (asyncTokenType)
727 result.types.push_back(builder.getType<AsyncTokenType>());
728
729 // Add grid and block sizes as op operands, followed by the data operands.
730 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
731 getBlockSizeY, getBlockSizeZ});
732 if (clusterSizeX)
733 result.addOperands(clusterSizeX);
734 if (clusterSizeY)
735 result.addOperands(clusterSizeY);
736 if (clusterSizeZ)
737 result.addOperands(clusterSizeZ);
738 if (dynamicSharedMemorySize)
739 result.addOperands(dynamicSharedMemorySize);
740
741 // Add optional module and function attributes.
742 if (module)
743 result.addAttribute(getModuleAttrName(result.name), module);
744 if (function)
745 result.addAttribute(getFunctionAttrName(result.name), function);
746
747 // Create a kernel body region with kNumConfigRegionAttributes + N memory
748 // attributions, where the first kNumConfigRegionAttributes arguments have
749 // `index` type and the rest have the same types as the data operands.
750 Region *kernelRegion = result.addRegion();
751 Block *body = builder.createBlock(kernelRegion);
752 // TODO: Allow passing in proper locations here.
753 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
754 body->addArgument(builder.getIndexType(), result.location);
755 // Add WorkGroup & Private attributions to the region arguments.
756 for (Type argTy : workgroupAttributions)
757 body->addArgument(argTy, result.location);
758 for (Type argTy : privateAttributions)
759 body->addArgument(argTy, result.location);
760 // Fill OperandSegmentSize Attribute.
761 SmallVector<int32_t, 11> segmentSizes(11, 1);
762 segmentSizes.front() = asyncDependencies.size();
763 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
764 segmentSizes[7] = clusterSizeX ? 1 : 0;
765 segmentSizes[8] = clusterSizeY ? 1 : 0;
766 segmentSizes[9] = clusterSizeZ ? 1 : 0;
767 result.addAttribute(getOperandSegmentSizeAttr(),
768 builder.getDenseI32ArrayAttr(segmentSizes));
769}
770
771KernelDim3 LaunchOp::getBlockIds() {
772 assert(!getBody().empty() && "LaunchOp body must not be empty.");
773 auto args = getBody().getArguments();
774 return KernelDim3{args[0], args[1], args[2]};
775}
776
777KernelDim3 LaunchOp::getThreadIds() {
778 assert(!getBody().empty() && "LaunchOp body must not be empty.");
779 auto args = getBody().getArguments();
780 return KernelDim3{args[3], args[4], args[5]};
781}
782
783KernelDim3 LaunchOp::getGridSize() {
784 assert(!getBody().empty() && "LaunchOp body must not be empty.");
785 auto args = getBody().getArguments();
786 return KernelDim3{args[6], args[7], args[8]};
787}
788
789KernelDim3 LaunchOp::getBlockSize() {
790 assert(!getBody().empty() && "LaunchOp body must not be empty.");
791 auto args = getBody().getArguments();
792 return KernelDim3{args[9], args[10], args[11]};
793}
794
795std::optional<KernelDim3> LaunchOp::getClusterIds() {
796 assert(!getBody().empty() && "LaunchOp body must not be empty.");
797 if (!hasClusterSize())
798 return std::nullopt;
799 auto args = getBody().getArguments();
800 return KernelDim3{args[12], args[13], args[14]};
801}
802
803std::optional<KernelDim3> LaunchOp::getClusterSize() {
804 assert(!getBody().empty() && "LaunchOp body must not be empty.");
805 if (!hasClusterSize())
806 return std::nullopt;
807 auto args = getBody().getArguments();
808 return KernelDim3{args[15], args[16], args[17]};
809}
810
811KernelDim3 LaunchOp::getGridSizeOperandValues() {
812 auto operands = getOperands().drop_front(getAsyncDependencies().size());
813 return KernelDim3{operands[0], operands[1], operands[2]};
814}
815
816KernelDim3 LaunchOp::getBlockSizeOperandValues() {
817 auto operands = getOperands().drop_front(getAsyncDependencies().size());
818 return KernelDim3{operands[3], operands[4], operands[5]};
819}
820
821std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
822 auto operands = getOperands().drop_front(getAsyncDependencies().size());
823 if (!hasClusterSize())
824 return std::nullopt;
825 return KernelDim3{operands[6], operands[7], operands[8]};
826}
827
828LogicalResult LaunchOp::verify() {
829 if (!(hasClusterSize()) &&
830 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
831 return emitOpError() << "cluster size must be all present";
832 return success();
833}
834
835LogicalResult LaunchOp::verifyRegions() {
836 // Kernel launch takes kNumConfigOperands leading operands for grid/block
837 // sizes and transforms them into kNumConfigRegionAttributes region arguments
838 // for block/thread identifiers and grid/block sizes.
839 if (getBody().empty()) {
840 return emitOpError("body region is empty");
841 }
842 if (getBody().getNumArguments() <
843 kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
844 return emitOpError("unexpected number of region arguments");
845 }
846
847 // Verify Attributions Address Spaces.
848 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
849 GPUDialect::getWorkgroupAddressSpace())) ||
850 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
851 GPUDialect::getPrivateAddressSpace())))
852 return failure();
853
854 // Block terminators without successors are expected to exit the kernel region
855 // and must be `gpu.terminator`.
856 for (Block &block : getBody()) {
857 if (block.empty())
858 continue;
859 if (block.back().getNumSuccessors() != 0)
860 continue;
861 if (!isa<gpu::TerminatorOp>(&block.back())) {
862 return block.back()
863 .emitError()
864 .append("expected '", gpu::TerminatorOp::getOperationName(),
865 "' or a terminator with successors")
866 .attachNote(getLoc())
867 .append("in '", LaunchOp::getOperationName(), "' body region");
868 }
869 }
870
871 if (getNumResults() == 0 && getAsyncToken())
872 return emitOpError("needs to be named when async keyword is specified");
873
874 return success();
875}
876
877// Pretty-print the kernel grid/block size assignment as
878// (%iter-x, %iter-y, %iter-z) in
879// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
880// where %size-* and %iter-* will correspond to the body region arguments.
882 KernelDim3 operands, KernelDim3 ids) {
883 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
884 p << size.x << " = " << operands.x << ", ";
885 p << size.y << " = " << operands.y << ", ";
886 p << size.z << " = " << operands.z << ')';
887}
888
889void LaunchOp::print(OpAsmPrinter &p) {
890 if (getAsyncToken()) {
891 p << " async";
892 if (!getAsyncDependencies().empty())
893 p << " [" << getAsyncDependencies() << ']';
894 }
895 // Print the launch configuration.
896 if (hasClusterSize()) {
897 p << ' ' << getClustersKeyword();
898 printSizeAssignment(p, getClusterSize().value(),
899 getClusterSizeOperandValues().value(),
900 getClusterIds().value());
901 }
902 p << ' ' << getBlocksKeyword();
903 printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
904 getBlockIds());
905 p << ' ' << getThreadsKeyword();
906 printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
907 getThreadIds());
908 if (getDynamicSharedMemorySize())
909 p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
910 << getDynamicSharedMemorySize();
911
912 // Print optional module attribute.
913 StringRef moduleAttrName = getModuleAttrName();
914 if (auto module = getModule()) {
915 p << ' ' << moduleAttrName << '(';
916 p.printSymbolName(*module);
917 p << ')';
918 }
919 // Print optional function attribute.
920 StringRef functionAttrName = getFunctionAttrName();
921 if (auto function = getFunction()) {
922 p << ' ' << functionAttrName << '(';
923 p.printSymbolName(*function);
924 p << ')';
925 }
926
927 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs());
928 printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
929
930 p << ' ';
931
932 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
933 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
934 LaunchOp::getOperandSegmentSizeAttr(),
935 getWorkgroupAttributionsAttrName(),
936 moduleAttrName, functionAttrName});
937}
938
939// Parse the size assignment blocks for blocks and threads. These have the form
940// (%region_arg, %region_arg, %region_arg) in
941// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
942// where %region_arg are percent-identifiers for the region arguments to be
943// introduced further (SSA defs), and %operand are percent-identifiers for the
944// SSA value uses.
// Parses one `(%id, %id, %id) in (%id = %operand, ...)` size-assignment
// segment of gpu.launch. The three parsed region ids are moved into
// `indices`; each per-dimension region size id and its bound SSA operand are
// stored into `regionSizes`/`sizes` respectively.
// NOTE(review): the parameter list (parser, sizes, regionSizes, indices) and
// the leading argument-list parse call were lost in extraction of this
// listing; the surviving tokens below are kept verbatim.
945 static ParseResult
950 StringRef keyword) {
951 assert(indices.size() == 3 && "space for three indices expected");
954 /*allowResultNumber=*/false) ||
955 parser.parseKeyword("in") || parser.parseLParen())
956 return failure();
957
// Exactly three region ids must precede the `in` keyword for this segment.
958 if (args.size() != 3) {
959 return parser.emitError(parser.getNameLoc())
960 << keyword << " expects 3 arguments, but got " << args.size();
961 }
// Hand the parsed ids over to the caller-provided storage.
962 std::move(args.begin(), args.end(), indices.begin());
963
// Parse the `(%region_arg = %operand, ...)` reassignment triple.
964 for (int i = 0; i < 3; ++i) {
965 if (i != 0 && parser.parseComma())
966 return failure();
967 if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
968 parser.parseEqual() || parser.parseOperand(sizes[i]))
969 return failure();
970 }
971
972 return parser.parseRParen();
973 }
974
975/// Parses a Launch operation.
976/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
977/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
978/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
979/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
980/// (`dynamic_shared_memory_size` ssa-use)?
981/// (`module(` symbol-ref-id `)`)?
982/// (`function(` symbol-ref-id `)`)?
983/// memory-attribution
984/// region attr-dict?
985/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
986 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
987 // Sizes of the grid and block.
988 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
989 sizes(LaunchOp::kNumConfigOperands);
990
991 // Region arguments to be created.
992 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
993 LaunchOp::kNumConfigRegionAttributes);
994
995 // Parse optional async dependencies.
996 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
997 Type asyncTokenType;
998 if (failed(
999 parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
1000 parser.resolveOperands(asyncDependencies, asyncTokenType,
1001 result.operands))
1002 return failure();
// A result is only legal when the `async` keyword supplied a token type.
1003 if (parser.getNumResults() > 0) {
1004 if (!asyncTokenType)
1005 return parser.emitError(
1006 parser.getNameLoc(),
1007 "gpu.launch requires 'async' keyword to return a value");
1008 result.types.push_back(asyncTokenType);
1009 }
1010
// With clusters present there are 9 size operands (cluster/grid/block) and
// 18 region arguments (ids + sizes for each of the three levels).
1011 bool hasCluster = false;
1012 if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
1013 hasCluster = true;
1014 sizes.resize(9);
1015 regionArgs.resize(18);
1016 }
1017 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1018 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1019
1020 // Last three segment assigns the cluster size. In the region argument
1021 // list, this is last 6 arguments.
1022 if (hasCluster) {
// NOTE(review): line 1023 (the `if (failed(parseSizeAssignment(` call head)
// is missing from this listing; the call arguments below are verbatim.
1024 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1025 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1026 return failure();
1027 }
1028 // Parse the size assignment segments: the first segment assigns grid sizes
1029 // and defines values for block identifiers; the second segment assigns block
1030 // sizes and defines values for thread identifiers. In the region argument
1031 // list, identifiers precede sizes, and block-related values precede
1032 // thread-related values.
1033 if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
1034 parseSizeAssignment(parser, sizesRef.take_front(3),
1035 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1036 LaunchOp::getBlocksKeyword()) ||
1037 parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
1038 parseSizeAssignment(parser, sizesRef.drop_front(3),
1039 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1040 LaunchOp::getThreadsKeyword()) ||
1041 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
1042 result.operands))
1043 return failure();
1044
// Optional `dynamic_shared_memory_size %size : i32` clause.
1045 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1046 bool hasDynamicSharedMemorySize = false;
1047 if (!parser.parseOptionalKeyword(
1048 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1049 hasDynamicSharedMemorySize = true;
1050 if (parser.parseOperand(dynamicSharedMemorySize) ||
1051 parser.resolveOperand(dynamicSharedMemorySize,
1052 parser.getBuilder().getI32Type(),
1053 result.operands))
1054 return failure();
1055 }
1056
1057 // Parse optional module attribute.
1058 StringRef moduleAttrName = getModuleAttrName(result.name);
1059 if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1060 FlatSymbolRefAttr moduleSymbol;
1061 if (parser.parseLParen() ||
1062 parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1063 result.attributes) ||
1064 parser.parseRParen())
1065 return failure();
1066 }
1067 // Parse optional function attribute.
1068 StringRef functionAttrName = getFunctionAttrName(result.name);
1069 if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1070 FlatSymbolRefAttr funcSymbol;
1071 if (parser.parseLParen() ||
1072 parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1073 result.attributes) ||
1074 parser.parseRParen())
1075 return failure();
1076 }
1077
1078 // Create the region arguments: fixed launch-config args (`index`), then
1079 // workgroup / private attribution args. The workgroup count is stored in the
1080 // inherent `workgroup_attributions` attribute when non-zero.
1081 Type index = parser.getBuilder().getIndexType();
1082 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1083 LaunchOp::kNumConfigRegionAttributes + 6, index);
1084
1085 SmallVector<OpAsmParser::Argument> regionArguments;
1086 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1087 OpAsmParser::Argument arg;
1088 arg.ssaName = std::get<0>(ssaValueAndType);
1089 arg.type = std::get<1>(ssaValueAndType);
1090 regionArguments.push_back(arg);
1091 }
1092
1093 Builder &builder = parser.getBuilder();
1094 // Parse workgroup memory attributions.
1095 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1096 regionArguments)))
1097 return failure();
1098
1099 // Store the number of operands we just parsed as the number of workgroup
1100 // memory attributions.
1101 unsigned numWorkgroupAttrs = regionArguments.size() -
1102 LaunchOp::kNumConfigRegionAttributes -
1103 (hasCluster ? 6 : 0);
1104 if (numWorkgroupAttrs != 0)
1105 result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
1106 builder.getI64IntegerAttr(numWorkgroupAttrs));
1107
1108 // Parse private memory attributions.
1109 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1110 regionArguments)))
1111 return failure();
1112
1113 // Introduce the body region and parse it. The region has
1114 // kNumConfigRegionAttributes arguments that correspond to
1115 // block/thread identifiers and grid/block sizes, all having `index` type.
1116 Region *body = result.addRegion();
1117 if (parser.parseRegion(*body, regionArguments) ||
1118 parser.parseOptionalAttrDict(result.attributes))
1119 return failure();
1120
// Build the operand segment sizes: [async, grid(3), block(3), cluster(3),
// dynamic shared memory]. Cluster and shared-memory segments are zeroed when
// absent.
1121 SmallVector<int32_t, 11> segmentSizes(11, 1);
1122 segmentSizes.front() = asyncDependencies.size();
1123
1124 if (!hasCluster) {
1125 segmentSizes[7] = 0;
1126 segmentSizes[8] = 0;
1127 segmentSizes[9] = 0;
1128 }
1129 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1130 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1131 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1132 return success();
1133 }
1134
1135/// Simplify the gpu.launch when the range of a thread or block ID is
1136/// trivially known to be one.
1137struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1138 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1139 LogicalResult matchAndRewrite(LaunchOp op,
1140 PatternRewriter &rewriter) const override {
1141 // If the range implies a single value for `id`, replace `id`'s uses by
1142 // zero.
1143 Value zero;
1144 bool simplified = false;
1145 auto constPropIdUses = [&](Value id, Value size) {
1146 // Check if size is trivially one.
1147 if (!matchPattern(size, m_One()))
1148 return;
1149 if (id.getUses().empty())
1150 return;
1151 if (!simplified) {
1152 // Create a zero value the first time.
1153 OpBuilder::InsertionGuard guard(rewriter);
1154 rewriter.setInsertionPointToStart(&op.getBody().front());
1155 zero =
1156 arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
1157 }
1158 rewriter.replaceAllUsesWith(id, zero);
1159 simplified = true;
1160 };
1161 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1162 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1163 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1164 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1165 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1166 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1167
1168 return success(simplified);
1169 }
1170};
1171
// Registers launch canonicalizations; currently only folding of ids whose
// range is statically one.
1172 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1173 MLIRContext *context) {
1174 rewrites.add<FoldLaunchArguments>(context);
1175 }
1176
1177/// Adds a new block argument that corresponds to buffers located in
1178/// workgroup memory.
1179BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1180 int64_t cur = getWorkgroupAttributions().value_or(0);
1181 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1182 return getBody().insertArgument(
1183 getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
1184}
1185
1186 /// Adds a new block argument that corresponds to buffers located in
1187 /// private memory.
1188 BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1189 // Buffers on the private memory always come after buffers on the workgroup
1190 // memory.
// Appending at the end therefore keeps the argument order invariant:
// config args, then workgroup attributions, then private attributions.
1191 return getBody().addArgument(type, loc);
1192 }
1193
1194//===----------------------------------------------------------------------===//
1195// LaunchFuncOp
1196//===----------------------------------------------------------------------===//
1197
// Builds a launch_func from an explicit `@module::@kernel` symbol reference,
// filling the operand list and the `operandSegmentSizes` property:
// [async deps, grid(3), block(3), cluster(3, optional), dynamic shared
// memory (optional), kernel operands, async object (none here)].
1198 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1199 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1200 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1201 ValueRange kernelOperands, Type asyncTokenType,
1202 ValueRange asyncDependencies,
1203 std::optional<KernelDim3> clusterSize) {
1204 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1205 "expected a symbol reference with a single nested reference");
1206 result.addOperands(asyncDependencies);
1207 if (asyncTokenType)
1208 result.types.push_back(builder.getType<AsyncTokenType>());
1209
1210 // Add grid and block sizes as op operands, followed by the data operands.
1211 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
// NOTE(review): line 1212 (the remaining block-size operands of this
// initializer list) is missing from this listing.
1213 if (clusterSize.has_value())
1214 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1215 if (dynamicSharedMemorySize)
1216 result.addOperands(dynamicSharedMemorySize);
1217 result.addOperands(kernelOperands);
1218
1219 Properties &prop = result.getOrAddProperties<Properties>();
1220 prop.kernel = kernelSymbol;
1221 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1222 // Initialize the segment sizes to 1.
1223 llvm::fill(prop.operandSegmentSizes, 1);
1224 prop.operandSegmentSizes[0] = asyncDependencies.size();
// Zero out the three cluster segments when no cluster size was provided.
1225 if (!clusterSize.has_value()) {
1226 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1227 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1228 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1229 }
1230 prop.operandSegmentSizes[segmentSizesLen - 3] =
1231 dynamicSharedMemorySize ? 1 : 0;
1232 prop.operandSegmentSizes[segmentSizesLen - 2] =
1233 static_cast<int32_t>(kernelOperands.size());
// No async object operand in this overload.
1234 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1235 }
1236
1237void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1238 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1239 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1240 ValueRange kernelOperands, Type asyncTokenType,
1241 ValueRange asyncDependencies,
1242 std::optional<KernelDim3> clusterSize) {
1243 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1244 auto kernelSymbol =
1245 SymbolRefAttr::get(kernelModule.getNameAttr(),
1246 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1247 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1248 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1249 asyncDependencies, clusterSize);
1250}
1251
// Builds a launch_func using an opaque async object (e.g. a stream) instead
// of async token dependencies; the async object, when present, is the last
// operand and occupies the final operand segment.
1252 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1253 SymbolRefAttr kernel, KernelDim3 gridSize,
1254 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1255 ValueRange kernelOperands, Value asyncObject,
1256 std::optional<KernelDim3> clusterSize) {
1257 // Add grid and block sizes as op operands, followed by the data operands.
1258 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
// NOTE(review): line 1259 (the remaining block-size operands of this
// initializer list) is missing from this listing.
1260 if (clusterSize.has_value())
1261 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1262 if (dynamicSharedMemorySize)
1263 result.addOperands(dynamicSharedMemorySize);
1264 result.addOperands(kernelOperands);
1265 if (asyncObject)
1266 result.addOperands(asyncObject);
1267 Properties &prop = result.getOrAddProperties<Properties>();
1268 prop.kernel = kernel;
1269 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1270 // Initialize the segment sizes to 1.
1271 llvm::fill(prop.operandSegmentSizes, 1);
// No async token dependencies in this overload.
1272 prop.operandSegmentSizes[0] = 0;
1273 if (!clusterSize.has_value()) {
1274 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1275 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1276 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1277 }
1278 prop.operandSegmentSizes[segmentSizesLen - 3] =
1279 dynamicSharedMemorySize ? 1 : 0;
1280 prop.operandSegmentSizes[segmentSizesLen - 2] =
1281 static_cast<int32_t>(kernelOperands.size());
1282 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1283 }
1284
// Returns the root of the nested kernel symbol, i.e. the name of the
// gpu.module that contains the kernel function.
1285 StringAttr LaunchFuncOp::getKernelModuleName() {
1286 return getKernel().getRootReference();
1287 }
1288
// Returns the leaf of the nested kernel symbol, i.e. the kernel function's
// own name.
1289 StringAttr LaunchFuncOp::getKernelName() {
1290 return getKernel().getLeafReference();
1291 }
1292
// Returns the number of operands forwarded to the kernel function.
1293 unsigned LaunchFuncOp::getNumKernelOperands() {
1294 return getKernelOperands().size();
1295 }
1296
// Returns the i-th operand forwarded to the kernel function.
1297 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1298 return getKernelOperands()[i];
1299 }
1300
1301KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1302 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1303 return KernelDim3{operands[0], operands[1], operands[2]};
1304}
1305
1306KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1307 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1308 return KernelDim3{operands[3], operands[4], operands[5]};
1309}
1310
1311KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1312 assert(hasClusterSize() &&
1313 "cluster size is not set, check hasClusterSize() first");
1314 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1315 return KernelDim3{operands[6], operands[7], operands[8]};
1316}
1317
1318LogicalResult LaunchFuncOp::verify() {
1319 auto module = (*this)->getParentOfType<ModuleOp>();
1320 if (!module)
1321 return emitOpError("expected to belong to a module");
1322
1323 if (!module->getAttrOfType<UnitAttr>(
1324 GPUDialect::getContainerModuleAttrName()))
1325 return emitOpError("expected the closest surrounding module to have the '" +
1326 GPUDialect::getContainerModuleAttrName() +
1327 "' attribute");
1328
1329 if (hasClusterSize()) {
1330 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1331 getClusterSizeZ().getType() != getClusterSizeX().getType())
1332 return emitOpError()
1333 << "expects types of the cluster dimensions must be the same";
1334 }
1335
1336 return success();
1337}
1338
// Verifies that the kernel symbol resolves to a well-formed GPU kernel: the
// container must be a gpu.module (or gpu.binary), the leaf must be a function
// carrying the kernel attribute, and - for gpu.func kernels - the operand
// types must match the kernel signature.
1339 LogicalResult
1340 LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1341 LaunchFuncOp launchOp = *this;
1342 Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
1343 // GPU modules cannot be nested within each other, escape to resolve the name.
1344 if (isa<GPUModuleOp>(table))
// NOTE(review): line 1345 (the body of this `if`, which escapes to an outer
// symbol table) is missing from this listing.
1346
1347 // Ignore launches that are nested more or less deep than functions in the
1348 // module we are currently checking.
1349 if (!launchOp->getParentOp() ||
1350 launchOp->getParentOp()->getParentOp() != table)
1351 return success();
1352
1353 // Ignore launch ops with missing attributes here. The errors will be
1354 // reported by the verifiers of those ops.
1355 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1356 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1357 return success();
1358
1359 // Check that `launch_func` refers to a well-formed GPU kernel container.
1360 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1361 Operation *kernelContainer =
1362 symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
1363 if (!kernelContainer)
1364 return launchOp.emitOpError()
1365 << "kernel container '" << kernelContainerName.getValue()
1366 << "' is undefined";
1367
1368 // If the container is a GPU binary op return success.
1369 if (isa<BinaryOp>(kernelContainer))
1370 return success();
1371
1372 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1373 if (!kernelModule)
1374 return launchOp.emitOpError()
1375 << "kernel module '" << kernelContainerName.getValue()
1376 << "' is undefined";
1377
1378 // Check that `launch_func` refers to a well-formed kernel function.
1379 Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
1380 kernelModule, launchOp.getKernelName());
1381 if (!kernelFunc)
1382 return launchOp.emitOpError("kernel function '")
1383 << launchOp.getKernel() << "' is undefined";
1384 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1385 if (!kernelConvertedFunction) {
1386 InFlightDiagnostic diag = launchOp.emitOpError()
1387 << "referenced kernel '" << launchOp.getKernel()
1388 << "' is not a function";
1389 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
1390 return diag;
1391 }
1392
1393 if (!GPUDialect::isKernel(kernelFunc))
1394 return launchOp.emitOpError("kernel function is missing the '")
1395 << GPUDialect::getKernelFuncAttrName() << "' attribute";
1396
1397 // TODO: If the kernel isn't a GPU function (which happens during separate
1398 // compilation), do not check type correspondence as it would require the
1399 // verifier to be aware of the type conversion.
1400 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1401 if (!kernelGPUFunction)
1402 return success();
1403
// Operand count and element-wise type matching against the gpu.func
// signature.
1404 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1406 if (expectedNumArguments != actualNumArguments)
1407 return launchOp.emitOpError("got ")
1408 << actualNumArguments << " kernel operands but expected "
1409 << expectedNumArguments;
1410
1411 FunctionType functionType = kernelGPUFunction.getFunctionType();
1412 for (unsigned i = 0; i < expectedNumArguments; ++i) {
1413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1414 return launchOp.emitOpError("type of function argument ")
1415 << i << " does not match";
1416 }
1417 }
1418
1419 return success();
1420 }
1421
// Parses the optional `: type` suffix of a launch dimension; defaults to
// `index` when absent. When a cluster value is present, the parsed type is
// propagated to all three cluster dimension types.
// NOTE(review): line 1423 (the function name and leading parameters,
// presumably `parseLaunchDimType(OpAsmParser &parser, Type &dimTy,`) is
// missing from this listing.
1422 static ParseResult
1424 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1425 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1426 if (succeeded(parser.parseOptionalColon())) {
1427 if (parser.parseType(dimTy))
1428 return failure();
1429 } else {
1430 dimTy = IndexType::get(parser.getContext());
1431 }
1432 if (clusterValue.has_value()) {
1433 clusterXTy = clusterYTy = clusterZTy = dimTy;
1434 }
1435 return success();
1436 }
1437
1438static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1439 Value clusterValue, Type clusterXTy,
1440 Type clusterYTy, Type clusterZTy) {
1441 if (!dimTy.isIndex())
1442 printer << ": " << dimTy;
1443}
1444
// Parses the optional `args(%v : type, ...)` clause of gpu.launch_func into
// parallel operand/type lists; absence of the `args` keyword is not an error.
// NOTE(review): line 1447 (the argNames output parameter) and line 1457 (the
// comma-separated-list parse call head) are missing from this listing.
1445 static ParseResult parseLaunchFuncOperands(
1446 OpAsmParser &parser,
1448 SmallVectorImpl<Type> &argTypes) {
1449 if (parser.parseOptionalKeyword("args"))
1450 return success();
1451
// Each element is `%operand : type`.
1452 auto parseElement = [&]() -> ParseResult {
1453 return failure(parser.parseOperand(argNames.emplace_back()) ||
1454 parser.parseColonType(argTypes.emplace_back()));
1455 };
1456
1458 parseElement, " in argument list");
1459 }
1460
// Prints the `args(%v : type, ...)` clause; elided entirely when there are
// no kernel operands.
// NOTE(review): line 1461 (the function header, presumably
// `static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,`)
// is missing from this listing.
1462 OperandRange operands, TypeRange types) {
1463 if (operands.empty())
1464 return;
1465 printer << "args(";
1466 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1467 [&](const auto &pair) {
1468 auto [operand, type] = pair;
1469 printer << operand << " : " << type;
1470 });
1471 printer << ")";
1472 }
1473
1474//===----------------------------------------------------------------------===//
1475// ShuffleOp
1476//===----------------------------------------------------------------------===//
1477
1478void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1479 int32_t offset, int32_t width, ShuffleMode mode) {
1480 build(builder, result, value,
1481 arith::ConstantOp::create(builder, result.location,
1482 builder.getI32IntegerAttr(offset)),
1483 arith::ConstantOp::create(builder, result.location,
1484 builder.getI32IntegerAttr(width)),
1485 mode);
1486}
1487
1488//===----------------------------------------------------------------------===//
1489// RotateOp
1490//===----------------------------------------------------------------------===//
1491
1492LogicalResult RotateOp::verify() {
1493 uint32_t offset = getOffset();
1494 uint32_t width = getWidth();
1495
1496 if (offset >= width) {
1497 return emitOpError() << "offset must be in the range [0, " << width << ")";
1498 }
1499
1500 return success();
1501}
1502
1503//===----------------------------------------------------------------------===//
1504// BarrierOp
1505//===----------------------------------------------------------------------===//
1506
1507/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1508static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1509 PatternRewriter &rewriter) {
1510 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1511 if (!nextOp)
1512 return failure();
1513
1514 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1515 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1516
1517 if (thisMemfence) {
1518 rewriter.modifyOpInPlace(op, [&]() {
1519 if (!nextMemfence) {
1520 op.removeAddressSpacesAttr();
1521 return;
1522 }
1523 // Fast path - merge where the two barriers fence the same spaces.
1524 if (*thisMemfence == *nextMemfence) {
1525 return;
1526 }
1527
1528 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1529 for (Attribute attr : *thisMemfence)
1530 mergedSpaces.insert(attr);
1531 for (Attribute attr : *nextMemfence)
1532 mergedSpaces.insert(attr);
1533 op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
1534 });
1535 }
1536
1537 rewriter.eraseOp(nextOp);
1538 return success();
1539}
1540
// Registers barrier canonicalizations (merging adjacent gpu.barrier ops).
// NOTE(review): line 1543 (the pattern registration call, presumably
// `results.add(eraseRedundantGpuBarrierOps);`) is missing from this listing.
1541 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1542 MLIRContext *context) {
1544 }
1545
1546void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1547 mlir::OperationState &odsState,
1548 std::optional<AddressSpace> addressSpace) {
1549 ArrayAttr addressSpacesAttr;
1550 if (addressSpace)
1551 addressSpacesAttr = odsBuilder.getArrayAttr(
1552 AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
1553 build(odsBuilder, odsState, addressSpacesAttr);
1554}
1555
1556/// Builds a barrier that causes memory operations affecting `memrefToFence` to
1557/// be completed after the barrier is concluded. Currently, this means setting
1558/// the fenced address spaces to those of the given memref if it is a gpu
1559/// address space.
1560void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1561 Value memrefToFence) {
1562 std::optional<AddressSpace> addrSpaceToFence;
1563 if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
1564 if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1565 memrefType.getMemorySpace()))
1566 addrSpaceToFence = addrSpaceAttr.getValue();
1567 return build(builder, odsState, addrSpaceToFence);
1568}
1569
1570//===----------------------------------------------------------------------===//
1571// GPUFuncOp
1572//===----------------------------------------------------------------------===//
1573
1574/// Adds a new block argument that corresponds to buffers located in
1575/// workgroup memory.
1576BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1577 int64_t cur = getWorkgroupAttributions().value_or(0);
1578 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1579 return getBody().insertArgument(
1580 getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
1581}
1582
1583 /// Adds a new block argument that corresponds to buffers located in
1584 /// private memory.
1585 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1586 // Buffers on the private memory always come after buffers on the workgroup
1587 // memory.
// Appending at the end therefore keeps the argument order invariant:
// function inputs, then workgroup attributions, then private attributions.
1588 return getBody().addArgument(type, loc);
1589 }
1590
// Builds a gpu.func with the given signature and attribution types. The
// entry block receives one argument per function input, then per workgroup
// attribution, then per private attribution.
1591 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1592 StringRef name, FunctionType type,
1593 TypeRange workgroupAttributions,
1594 TypeRange privateAttributions,
1595 ArrayRef<NamedAttribute> attrs) {
1596 OpBuilder::InsertionGuard g(builder);
1597
// NOTE(review): line 1598 (the addAttribute call head for the symbol name,
// presumably `result.addAttribute(SymbolTable::getSymbolAttrName(),`) is
// missing from this listing.
1599 builder.getStringAttr(name));
1600 result.addAttribute(getFunctionTypeAttrName(result.name),
1601 TypeAttr::get(type));
1602 result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
1603 builder.getI64IntegerAttr(workgroupAttributions.size()));
1604 result.addAttributes(attrs);
1605 Region *body = result.addRegion();
1606 Block *entryBlock = builder.createBlock(body);
1607
1608 // TODO: Allow passing in proper locations here.
1609 for (Type argTy : type.getInputs())
1610 entryBlock->addArgument(argTy, result.location);
1611 for (Type argTy : workgroupAttributions)
1612 entryBlock->addArgument(argTy, result.location);
1613 for (Type argTy : privateAttributions)
1614 entryBlock->addArgument(argTy, result.location);
1615 }
1616
1617 /// Parses a GPU function memory attribution.
1618 ///
1619 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1620 /// (`private` `(` ssa-id-and-type-list `)`)?
1621 ///
1622 /// Note that this function parses only one of the two similar parts, with the
1623 /// keyword provided as argument.
// NOTE(review): line 1626 (the `args` output parameter of the signature) and
// line 1634 (the argument-list parse call head) are missing from this
// listing.
1624 static ParseResult
1625 parseAttributions(OpAsmParser &parser, StringRef keyword,
1627 Attribute &attributionAttrs) {
1628 // If we could not parse the keyword, just assume empty list and succeed.
1629 if (failed(parser.parseOptionalKeyword(keyword)))
1630 return success();
1631
1632 size_t existingArgs = args.size();
1633 ParseResult result =
1635 /*allowType=*/true, /*allowAttrs=*/true);
1636 if (failed(result))
1637 return result;
1638
// Only materialize the per-attribution attribute array when at least one of
// the freshly parsed arguments actually carried attributes.
1639 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1640 [](const OpAsmParser::Argument &arg) -> bool {
1641 return arg.attrs && !arg.attrs.empty();
1642 });
1643 if (!hadAttrs) {
1644 attributionAttrs = nullptr;
1645 return result;
1646 }
1647
// One dictionary per parsed attribution; missing ones become empty dicts so
// indices stay aligned.
1648 Builder &builder = parser.getBuilder();
1649 SmallVector<Attribute> attributionAttrsVec;
1650 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1651 if (!argument.attrs)
1652 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1653 else
1654 attributionAttrsVec.push_back(argument.attrs);
1655 }
1656 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1657 return result;
1658 }
1659
1660 /// Parses a GPU function.
1661 ///
1662 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1663 /// (`->` function-result-list)? memory-attribution `kernel`?
1664 /// function-attributes? region
1665 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1666 SmallVector<OpAsmParser::Argument> entryArgs;
1667 SmallVector<DictionaryAttr> resultAttrs;
1668 SmallVector<Type> resultTypes;
1669 bool isVariadic;
1670
1671 // Parse the function name.
1672 StringAttr nameAttr;
// NOTE(review): line 1673 (the symbol-name parse call head) is missing from
// this listing.
1674 result.attributes))
1675 return failure();
1676
// NOTE(review): line 1678 (the function-signature parse call head) is
// missing from this listing.
1677 auto signatureLocation = parser.getCurrentLocation();
1679 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1680 resultAttrs)))
1681 return failure();
1682
1683 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1684 return parser.emitError(signatureLocation)
1685 << "gpu.func requires named arguments";
1686
1687 // Construct the function type. More types will be added to the region, but
1688 // not to the function type.
1689 Builder &builder = parser.getBuilder();
1690
1691 SmallVector<Type> argTypes;
1692 for (auto &arg : entryArgs)
1693 argTypes.push_back(arg.type);
1694 auto type = builder.getFunctionType(argTypes, resultTypes);
1695 result.addAttribute(getFunctionTypeAttrName(result.name),
1696 TypeAttr::get(type));
1697
// NOTE(review): line 1698 (the call head that stores the parsed arg/result
// attributes) is missing from this listing.
1699 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1700 getResAttrsAttrName(result.name));
1701
1702 Attribute workgroupAttributionAttrs;
1703 // Parse workgroup memory attributions.
1704 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1705 entryArgs, workgroupAttributionAttrs)))
1706 return failure();
1707
1708 // Store the number of operands we just parsed as the number of workgroup
1709 // memory attributions.
1710 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1711 if (numWorkgroupAttrs != 0)
1712 result.addAttribute(
1713 GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
1714 builder.getI64IntegerAttr(numWorkgroupAttrs));
1715 if (workgroupAttributionAttrs)
1716 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1717 workgroupAttributionAttrs);
1718
1719 Attribute privateAttributionAttrs;
1720 // Parse private memory attributions.
1721 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1722 entryArgs, privateAttributionAttrs)))
1723 return failure();
1724 if (privateAttributionAttrs)
1725 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1726 privateAttributionAttrs);
1727
1728 // Parse the kernel attribute if present.
1729 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1730 result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
1731 builder.getUnitAttr());
1732
1733 // Parse attributes.
1734 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1735 return failure();
1736
1737 // Parse the region. If no argument names were provided, take all names
1738 // (including those of attributions) from the entry block.
1739 auto *body = result.addRegion();
1740 return parser.parseRegion(*body, entryArgs);
1741 }
1742
// Prints a gpu.func: symbol name, signature, memory attributions, optional
// `kernel` keyword, remaining attributes, then the body region without entry
// block arguments (they were printed with the signature/attributions).
1743 void GPUFuncOp::print(OpAsmPrinter &p) {
1744 p << ' ';
1745 p.printSymbolName(getName());
1746
1747 FunctionType type = getFunctionType();
1748 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1749 /*isVariadic=*/false,
1750 type.getResults());
1751
1752 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs(),
1753 getWorkgroupAttribAttrs().value_or(nullptr));
1754 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1755 getPrivateAttribAttrs().value_or(nullptr));
1756 if (isKernel())
1757 p << ' ' << getKernelKeyword();
1758
// NOTE(review): line 1759 (the call head printing remaining function
// attributes, with the list below as elided attrs) is missing from this
// listing.
1760 p, *this,
1761 {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
1762 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1763 getArgAttrsAttrName(), getResAttrsAttrName(),
1764 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1765 p << ' ';
1766 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1767 }
1768
1769static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1770 StringAttr attrName) {
1771 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1772 if (!allAttrs || index >= allAttrs.size())
1773 return DictionaryAttr();
1774 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1775}
1776
/// Returns the attribute dictionary attached to the workgroup attribution at
/// `index`, or null if none was set.
/// NOTE(review): the lowercase 'w' is unusual but part of the public API;
/// renaming would break callers.
DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}
1780
/// Returns the attribute dictionary attached to the private attribution at
/// `index`, or null if none was set.
DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}
1784
1785static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1786 DictionaryAttr value, StringAttr attrName) {
1787 MLIRContext *ctx = op.getContext();
1788 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1789 SmallVector<Attribute> elements;
1790 if (allAttrs)
1791 elements.append(allAttrs.begin(), allAttrs.end());
1792 while (elements.size() <= index)
1793 elements.push_back(DictionaryAttr::get(ctx));
1794 if (!value)
1795 elements[index] = DictionaryAttr::get(ctx);
1796 else
1797 elements[index] = value;
1798 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1799 op->setAttr(attrName, newValue);
1800}
1801
/// Sets the attribute dictionary for the workgroup attribution at `index`.
/// NOTE(review): lowercase 'w' matches the getter; part of the public API.
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}
1806
/// Sets the attribute dictionary for the private attribution at `index`.
void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
1811
1812static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1813 StringAttr name, StringAttr attrsName) {
1814 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1815 if (!dict)
1816 return Attribute();
1817 return dict.get(name);
1818}
1819
/// Returns the attribute `name` attached to the workgroup attribution at
/// `index`, or null if not present. `index` must be in range (asserted).
Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}
1827
/// Returns the attribute `name` attached to the private attribution at
/// `index`, or null if not present. `index` must be in range (asserted).
Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
1835
1836static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1837 Attribute value, StringAttr attrsName) {
1838 MLIRContext *ctx = op.getContext();
1840 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1841 if (oldDict)
1842 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1843
1844 bool found = false;
1845 bool mustSort = true;
1846 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1847 if (elems[i].getName() == name) {
1848 found = true;
1849 if (!value) {
1850 std::swap(elems[i], elems[elems.size() - 1]);
1851 elems.pop_back();
1852 } else {
1853 mustSort = false;
1854 elems[i] = NamedAttribute(elems[i].getName(), value);
1855 }
1856 break;
1857 }
1858 }
1859 if (!found) {
1860 if (!value)
1861 return;
1862 elems.emplace_back(name, value);
1863 }
1864 if (mustSort) {
1865 DictionaryAttr::sortInPlace(elems);
1866 }
1867 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1868 setAttributionAttrs(op, index, newDict, attrsName);
1869}
1870
/// Sets (or erases, when `value` is null) the attribute `name` on the
/// workgroup attribution at `index`. `index` must be in range (asserted).
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}
1878
/// Sets (or erases, when `value` is null) the attribute `name` on the
/// private attribution at `index`. `index` must be in range (asserted).
void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
1886
/// Verifies the function type: kernel functions are launched for their side
/// effects and must therefore not return any values.
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}
1893
1894/// Verifies the body of the function.
1895LogicalResult GPUFuncOp::verifyBody() {
1896 if (empty())
1897 return emitOpError() << "expected body with at least one block";
1898 unsigned numFuncArguments = getNumArguments();
1899 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1900 unsigned numBlockArguments = front().getNumArguments();
1901 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1902 return emitOpError() << "expected at least "
1903 << numFuncArguments + numWorkgroupAttributions
1904 << " arguments to body region";
1905
1906 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1907 for (unsigned i = 0; i < numFuncArguments; ++i) {
1908 Type blockArgType = front().getArgument(i).getType();
1909 if (funcArgTypes[i] != blockArgType)
1910 return emitOpError() << "expected body region argument #" << i
1911 << " to be of type " << funcArgTypes[i] << ", got "
1912 << blockArgType;
1913 }
1914
1915 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
1916 GPUDialect::getWorkgroupAddressSpace())) ||
1917 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1918 GPUDialect::getPrivateAddressSpace())))
1919 return failure();
1920
1921 return success();
1922}
1923
1924//===----------------------------------------------------------------------===//
1925// ReturnOp
1926//===----------------------------------------------------------------------===//
1927
1928LogicalResult gpu::ReturnOp::verify() {
1929 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1930
1931 FunctionType funType = function.getFunctionType();
1932
1933 if (funType.getNumResults() != getOperands().size())
1934 return emitOpError()
1935 .append("expected ", funType.getNumResults(), " result operands")
1936 .attachNote(function.getLoc())
1937 .append("return type declared here");
1938
1939 for (const auto &pair : llvm::enumerate(
1940 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1941 auto [type, operand] = pair.value();
1942 if (type != operand.getType())
1943 return emitOpError() << "unexpected type `" << operand.getType()
1944 << "' for operand #" << pair.index();
1945 }
1946 return success();
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// GPUModuleOp
1951//===----------------------------------------------------------------------===//
1952
1953void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1954 StringRef name, ArrayAttr targets,
1955 Attribute offloadingHandler) {
1956 result.addRegion()->emplaceBlock();
1957 Properties &props = result.getOrAddProperties<Properties>();
1958 if (targets)
1959 props.targets = targets;
1960 props.setSymName(builder.getStringAttr(name));
1961 props.offloadingHandler = offloadingHandler;
1962}
1963
1964void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1965 StringRef name, ArrayRef<Attribute> targets,
1966 Attribute offloadingHandler) {
1967 build(builder, result, name,
1968 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1969 offloadingHandler);
1970}
1971
1972bool GPUModuleOp::hasTarget(Attribute target) {
1973 if (ArrayAttr targets = getTargetsAttr())
1974 return llvm::count(targets.getValue(), target);
1975 return false;
1976}
1977
1978void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1979 ArrayAttr &targetsAttr = getProperties().targets;
1980 SmallVector<Attribute> targetsVector(targets);
1981 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1982}
1983
1984LogicalResult GPUModuleOp::verify() {
1985 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1986
1987 if (!targets)
1988 return success();
1989
1990 for (auto target : targets) {
1991 if (auto verifyTargetAttr =
1992 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1993 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1994 return failure();
1995 }
1996 }
1997 return success();
1998}
1999
2000//===----------------------------------------------------------------------===//
2001// GPUBinaryOp
2002//===----------------------------------------------------------------------===//
2003void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2004 Attribute offloadingHandler, ArrayAttr objects) {
2005 auto &properties = result.getOrAddProperties<Properties>();
2006 result.attributes.push_back(builder.getNamedAttr(
2008 properties.objects = objects;
2009 if (offloadingHandler)
2010 properties.offloadingHandler = offloadingHandler;
2011 else
2012 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
2013}
2014
2015void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2016 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2017 build(builder, result, name, offloadingHandler,
2018 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
2019}
2020
2021static ParseResult parseOffloadingHandler(OpAsmParser &parser,
2022 Attribute &offloadingHandler) {
2023 if (succeeded(parser.parseOptionalLess())) {
2024 if (parser.parseAttribute(offloadingHandler))
2025 return failure();
2026 if (parser.parseGreater())
2027 return failure();
2028 }
2029 if (!offloadingHandler)
2030 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2031 return success();
2032}
2033
2035 Attribute offloadingHandler) {
2036 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2037 printer << '<' << offloadingHandler << '>';
2038}
2039
2040//===----------------------------------------------------------------------===//
2041// GPUMemcpyOp
2042//===----------------------------------------------------------------------===//
2043
2044LogicalResult MemcpyOp::verify() {
2045 auto srcType = getSrc().getType();
2046 auto dstType = getDst().getType();
2047
2048 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2049 return emitOpError("arguments have incompatible element type");
2050
2051 if (failed(verifyCompatibleShape(srcType, dstType)))
2052 return emitOpError("arguments have incompatible shape");
2053
2054 return success();
2055}
2056
2057namespace {
2058
2059/// Erases a common case of copy ops where a destination value is used only by
2060/// the copy op, alloc and dealloc ops.
2061struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2062 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2063
2064 LogicalResult matchAndRewrite(MemcpyOp op,
2065 PatternRewriter &rewriter) const override {
2066 Value dest = op.getDst();
2067 Operation *destDefOp = dest.getDefiningOp();
2068 // `dest` must be defined by an op having Allocate memory effect in order to
2069 // perform the folding.
2070 if (!destDefOp ||
2072 return failure();
2073 // We can erase `op` iff `dest` has no other use apart from its
2074 // use by `op` and dealloc ops.
2075 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2076 return user != op &&
2077 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2078 }))
2079 return failure();
2080 // We can perform the folding if and only if op has a single async
2081 // dependency and produces an async token as result, or if it does not have
2082 // any async dependency and does not produce any async token result.
2083 if (op.getAsyncDependencies().size() > 1 ||
2084 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2085 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2086 return failure();
2087 rewriter.replaceOp(op, op.getAsyncDependencies());
2088 return success();
2089 }
2090};
2091
2092} // end anonymous namespace
2093
/// Registers the memcpy canonicalization that erases copies into otherwise
/// unused allocations (EraseTrivialCopyOp above).
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
2098
2099//===----------------------------------------------------------------------===//
2100// GPU_SubgroupMmaLoadMatrixOp
2101//===----------------------------------------------------------------------===//
2102
2103LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2104 auto srcType = getSrcMemref().getType();
2105 auto resType = getRes().getType();
2106 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2107 auto operand = resMatrixType.getOperand();
2108 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2109
2110 if (!srcMemrefType.isLastDimUnitStride())
2111 return emitError(
2112 "expected source memref most minor dim must have unit stride");
2113
2114 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2115 return emitError("only AOp, BOp and COp can be loaded");
2116
2117 return success();
2118}
2119
2120//===----------------------------------------------------------------------===//
2121// GPU_SubgroupMmaStoreMatrixOp
2122//===----------------------------------------------------------------------===//
2123
2124LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2125 auto srcType = getSrc().getType();
2126 auto dstType = getDstMemref().getType();
2127 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2128 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2129
2130 if (!dstMemrefType.isLastDimUnitStride())
2131 return emitError(
2132 "expected destination memref most minor dim must have unit stride");
2133
2134 if (srcMatrixType.getOperand() != "COp")
2135 return emitError(
2136 "expected the operand matrix being stored to have 'COp' operand type");
2137
2138 return success();
2139}
2140
2141//===----------------------------------------------------------------------===//
2142// GPU_SubgroupMmaComputeOp
2143//===----------------------------------------------------------------------===//
2144
2145LogicalResult SubgroupMmaComputeOp::verify() {
2146 enum OperandMap { A, B, C };
2147 SmallVector<MMAMatrixType, 3> opTypes;
2148 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2149 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2150 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2151
2152 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2153 opTypes[C].getOperand() != "COp")
2154 return emitError("operands must be in the order AOp, BOp, COp");
2155
2156 ArrayRef<int64_t> aShape, bShape, cShape;
2157 aShape = opTypes[A].getShape();
2158 bShape = opTypes[B].getShape();
2159 cShape = opTypes[C].getShape();
2160
2161 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2162 bShape[1] != cShape[1])
2163 return emitError("operand shapes do not satisfy matmul constraints");
2164
2165 return success();
2166}
2167
/// Folds memref.cast-fed operands in place via the shared memref helper.
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2172
/// Folds memref.cast-fed operands in place via the shared memref helper.
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2177
2178//===----------------------------------------------------------------------===//
2179// GPU_WaitOp
2180//===----------------------------------------------------------------------===//
2181
2182namespace {
2183
2184/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2185/// %t = gpu.wait async [] // No async dependencies.
2186/// ... gpu.wait ... [%t, ...] // %t can be removed.
2187struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2188public:
2190
2191 LogicalResult matchAndRewrite(WaitOp op,
2192 PatternRewriter &rewriter) const final {
2193 auto predicate = [](Value value) {
2194 auto waitOp = value.getDefiningOp<WaitOp>();
2195 return waitOp && waitOp->getNumOperands() == 0;
2196 };
2197 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2198 return failure();
2199 SmallVector<Value> validOperands;
2200 for (Value operand : op->getOperands()) {
2201 if (predicate(operand))
2202 continue;
2203 validOperands.push_back(operand);
2204 }
2205 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2206 return success();
2207 }
2208};
2209
2210/// Simplify trivial gpu.wait ops for the following patterns.
2211/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2212/// dependencies).
2213/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2214/// %t0.
2215/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2216/// dependencies nor return any token.
2217struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2218public:
2220
2221 LogicalResult matchAndRewrite(WaitOp op,
2222 PatternRewriter &rewriter) const final {
2223 // Erase gpu.wait ops that neither have any async dependencies nor return
2224 // any async token.
2225 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2226 rewriter.eraseOp(op);
2227 return success();
2228 }
2229 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2230 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2231 op.getAsyncToken()) {
2232 rewriter.replaceOp(op, op.getAsyncDependencies());
2233 return success();
2234 }
2235 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2236 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2237 rewriter.eraseOp(op);
2238 return success();
2239 }
2240 return failure();
2241 }
2242};
2243
2244} // end anonymous namespace
2245
/// Registers the gpu.wait canonicalizations defined above: redundant
/// wait-pair erasure and trivial-wait simplification.
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
2250
2251//===----------------------------------------------------------------------===//
2252// GPU_AllocOp
2253//===----------------------------------------------------------------------===//
2254
2255LogicalResult AllocOp::verify() {
2256 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2257
2258 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2259 getDynamicSizes())))
2260 return failure();
2261
2262 unsigned numSymbols = 0;
2263 if (!memRefType.getLayout().isIdentity())
2264 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2265 if (getSymbolOperands().size() != numSymbols) {
2266 return emitOpError(
2267 "symbol operand count does not equal memref symbol count");
2268 }
2269
2270 return success();
2271}
2272
2273namespace {
2274
2275/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2276/// `memref::AllocOp`.
2277struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2278 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2279
2280 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2281 PatternRewriter &rewriter) const override {
2282 std::optional<int64_t> index = dimOp.getConstantIndex();
2283 if (!index)
2284 return failure();
2285
2286 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2287 if (!memrefType || index.value() >= memrefType.getRank() ||
2288 !memrefType.isDynamicDim(index.value()))
2289 return failure();
2290
2291 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2292 if (!alloc)
2293 return failure();
2294
2295 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2296 memrefType.getDynamicDimIndex(index.value()));
2297 rewriter.replaceOp(dimOp, substituteOp);
2298 return success();
2299 }
2300};
2301
2302} // namespace
2303
/// Registers the memref.dim-of-gpu.alloc folding pattern defined above.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
2308
2309//===----------------------------------------------------------------------===//
2310// GPU object attribute
2311//===----------------------------------------------------------------------===//
2312
2313LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2314 Attribute target, CompilationTarget format,
2315 StringAttr object, DictionaryAttr properties,
2316 KernelTableAttr kernels) {
2317 if (!target)
2318 return emitError() << "the target attribute cannot be null";
2319 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2320 return success();
2321 return emitError() << "the target attribute must implement or promise the "
2322 "`gpu::TargetAttrInterface`";
2323}
2324
2325namespace {
/// Parses the `[format =] object` portion of a GPU ObjectAttr. The format
/// keyword and `=` are optional; the format defaults to Fatbin.
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  // No keyword present: use the default Fatbin format.
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  // A keyword was parsed; if it names a known format it must be followed by
  // an equal sign.
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  // The keyword did not name a known format.
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}
2350
/// Prints the `[format =] object` portion of a GPU ObjectAttr; the default
/// Fatbin format is elided, mirroring parseObject above.
void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
2357} // namespace
2358
2359//===----------------------------------------------------------------------===//
2360// GPU select object attribute
2361//===----------------------------------------------------------------------===//
2362
2363LogicalResult
2364gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2365 Attribute target) {
2366 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2367 if (target) {
2368 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2369 if (intAttr.getInt() < 0) {
2370 return emitError() << "the object index must be positive";
2371 }
2372 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2373 return emitError()
2374 << "the target attribute must be a GPU Target attribute";
2375 }
2376 }
2377 return success();
2378}
2379
2380//===----------------------------------------------------------------------===//
2381// DynamicSharedMemoryOp
2382//===----------------------------------------------------------------------===//
2383
2384LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2385 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2386 return emitOpError() << "must be inside an op with symbol table";
2387
2388 MemRefType memrefType = getResultMemref().getType();
2389 // Check address space
2390 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2391 return emitOpError() << "address space must be "
2392 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2393 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2394 }
2395 if (memrefType.hasStaticShape()) {
2396 return emitOpError() << "result memref type must be memref<?xi8, "
2397 "#gpu.address_space<workgroup>>";
2398 }
2399 return success();
2400}
2401
2402//===----------------------------------------------------------------------===//
2403// GPU WarpExecuteOnLane0Op
2404//===----------------------------------------------------------------------===//
2405
2406void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2407 p << "(" << getLaneid() << ")";
2408
2409 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2410 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2411 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2412
2413 if (!getArgs().empty())
2414 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2415 if (!getResults().empty())
2416 p << " -> (" << getResults().getTypes() << ')';
2417 p << " ";
2418 p.printRegion(getRegion(),
2419 /*printEntryBlockArgs=*/true,
2420 /*printBlockTerminators=*/!getResults().empty());
2421 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2422}
2423
/// Parses the custom assembly for warp_execute_on_lane_0:
///   `(` %laneid `)` `[` warp-size `]` (`args` `(` operands `:` types `)`)?
///   (`->` `(` result-types `)`)? region attr-dict?
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  // Parse the mandatory `[warp-size]` and store it as an i64 attribute.
  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  // The lane id is always an index-typed value.
  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  // Parse the optional `args(...)` clause of values forwarded to the region.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  // Ensure the region has a gpu.yield terminator even when it was elided.
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
2480
2481void WarpExecuteOnLane0Op::getSuccessorRegions(
2482 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2483 if (!point.isParent()) {
2484 regions.push_back(RegionSuccessor::parent());
2485 return;
2486 }
2487
2488 // The warp region is always executed
2489 regions.push_back(RegionSuccessor(&getWarpRegion()));
2490}
2491
2492ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2493 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2494}
2495void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2496 TypeRange resultTypes, Value laneId,
2497 int64_t warpSize) {
2498 build(builder, result, resultTypes, laneId, warpSize,
2499 /*operands=*/{}, /*argTypes=*/{});
2500}
2501
2502void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2503 TypeRange resultTypes, Value laneId,
2504 int64_t warpSize, ValueRange args,
2505 TypeRange blockArgTypes) {
2506 result.addOperands(laneId);
2507 result.addAttribute(getAttributeNames()[0],
2508 builder.getI64IntegerAttr(warpSize));
2509 result.addTypes(resultTypes);
2510 result.addOperands(args);
2511 assert(args.size() == blockArgTypes.size());
2512 OpBuilder::InsertionGuard guard(builder);
2513 Region *warpRegion = result.addRegion();
2514 Block *block = builder.createBlock(warpRegion);
2515 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2516 block->addArgument(type, arg.getLoc());
2517}
2518
/// Helper check if the distributed vector type is consistent with the expanded
/// type and distributed size.
/// A distributed type is valid when it is a vector of the same rank and
/// element type as the expanded one, every expanded dimension is a multiple
/// of the corresponding distributed dimension, and the product of the
/// per-dimension ratios equals the warp size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types matches there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  // Per-dimension ratio between the expanded and distributed sizes.
  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multipler of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  // The distribution must account for exactly `warpSize` lanes in total.
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
2555
/// Verifies the op: operand/block-argument and yield/result counts must
/// match, the body must end in gpu.yield, and each paired type must be a
/// valid warp distribution of the other (see verifyDistributedType).
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
  if (!yield)
    return emitOpError("expected body to be terminated with 'gpu.yield'");
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  // Forwarded operands: region args hold the expanded form, op operands the
  // distributed form.
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  // Results: yield operands hold the expanded form, op results the
  // distributed form.
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}
/// Two types are compatible for this op when one is a valid distributed form
/// of the other with respect to the op's warp size.
bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
2585
/// Returns the body's gpu.yield terminator (the verifier guarantees it).
gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
2589
2590//===----------------------------------------------------------------------===//
2591// GPU_SubgroupBroadcastOp
2592//===----------------------------------------------------------------------===//
2593
/// The broadcast result carries the same integer range as its source operand.
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}
2598
2599Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2600 switch (getBroadcastType()) {
2601 case BroadcastType::first_active_lane:
2602 // Cannot speculate first_lane broadcast, because speculating it across
2603 // control flow can change the active lanes.
2605 case BroadcastType::specific_lane:
2606 // Speculation should be safe as long as we inside structured control flow.
2608 }
2609 llvm_unreachable("Unknown BroadcastType");
2610}
2611
2612LogicalResult gpu::SubgroupBroadcastOp::verify() {
2613 switch (getBroadcastType()) {
2614 case BroadcastType::first_active_lane:
2615 if (getLane())
2616 return emitOpError()
2617 << "lane can only be specified for `specific_lane` broadcast";
2618 return success();
2619 case BroadcastType::specific_lane:
2620 if (!getLane())
2621 return emitOpError()
2622 << "lane must be specified for `specific_lane` broadcast";
2623 return success();
2624 }
2625 llvm_unreachable("Unknown BroadcastType");
2626}
2627
2628OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2629 // Broadcast result is always uniform.
2630 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2631 return prev.getResult();
2632
2633 return nullptr;
2634}
2635
2636//===----------------------------------------------------------------------===//
2637// GPU_BallotOp
2638//===----------------------------------------------------------------------===//
2639
2640// No custom implementations needed; ballot uses default behavior from ODS.
2641
2642//===----------------------------------------------------------------------===//
2643// GPU KernelMetadataAttr
2644//===----------------------------------------------------------------------===//
2645
2646KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2647 DictionaryAttr metadata) {
2648 assert(kernel && "invalid kernel");
2649 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2650 kernel.getAllArgAttrs(), metadata);
2651}
2652
2653KernelMetadataAttr
2654KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2655 FunctionOpInterface kernel,
2656 DictionaryAttr metadata) {
2657 assert(kernel && "invalid kernel");
2658 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2659 kernel.getAllArgAttrs(), metadata);
2660}
2661
2662KernelMetadataAttr
2663KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2664 if (attrs.empty())
2665 return *this;
2666 NamedAttrList attrList;
2667 if (DictionaryAttr dict = getMetadata())
2668 attrList.append(dict);
2669 attrList.append(attrs);
2670 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2671 attrList.getDictionary(getContext()));
2672}
2673
2674LogicalResult
2675KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2676 StringAttr name, Type functionType,
2677 ArrayAttr argAttrs, DictionaryAttr metadata) {
2678 if (name.empty())
2679 return emitError() << "the kernel name can't be empty";
2680 if (argAttrs) {
2681 if (llvm::any_of(argAttrs, [](Attribute attr) {
2682 return !llvm::isa<DictionaryAttr>(attr);
2683 }))
2684 return emitError()
2685 << "all attributes in the array must be a dictionary attribute";
2686 }
2687 return success();
2688}
2689
2690//===----------------------------------------------------------------------===//
2691// GPU KernelTableAttr
2692//===----------------------------------------------------------------------===//
2693
2694KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2695 ArrayRef<KernelMetadataAttr> kernels,
2696 bool isSorted) {
2697 // Note that `is_sorted` is always only invoked once even with assertions ON.
2698 assert((!isSorted || llvm::is_sorted(kernels)) &&
2699 "expected a sorted kernel array");
2700 // Immediately return the attribute if the array is sorted.
2701 if (isSorted || llvm::is_sorted(kernels))
2702 return Base::get(context, kernels);
2703 // Sort the array.
2704 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2705 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2706 return Base::get(context, kernelsTmp);
2707}
2708
2709KernelTableAttr KernelTableAttr::getChecked(
2710 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2711 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2712 // Note that `is_sorted` is always only invoked once even with assertions ON.
2713 assert((!isSorted || llvm::is_sorted(kernels)) &&
2714 "expected a sorted kernel array");
2715 // Immediately return the attribute if the array is sorted.
2716 if (isSorted || llvm::is_sorted(kernels))
2717 return Base::getChecked(emitError, context, kernels);
2718 // Sort the array.
2719 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2720 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2721 return Base::getChecked(emitError, context, kernelsTmp);
2722}
2723
2724LogicalResult
2725KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2726 ArrayRef<KernelMetadataAttr> kernels) {
2727 if (kernels.size() < 2)
2728 return success();
2729 // Check that the kernels are uniquely named.
2730 if (std::adjacent_find(kernels.begin(), kernels.end(),
2731 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2732 return l.getName() == r.getName();
2733 }) != kernels.end()) {
2734 return emitError() << "expected all kernels to be uniquely named";
2735 }
2736 return success();
2737}
2738
2739KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2740 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2741 return found ? *iterator : KernelMetadataAttr();
2742}
2743
2744KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2745 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2746 return found ? *iterator : KernelMetadataAttr();
2747}
2748
2749//===----------------------------------------------------------------------===//
2750// GPU target options
2751//===----------------------------------------------------------------------===//
2752
2767
2785
/// Returns the type ID identifying this (possibly derived) options class.
TypeID TargetOptions::getTypeID() const { return typeID; }
2787
/// Returns the path to the target toolkit.
StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2789
2793
/// Returns the raw (untokenized) command line options string.
StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2795
/// Returns the ELF section in which the binary should be placed.
StringRef TargetOptions::getELFSection() const { return elfSection; }
2797
2801
2802function_ref<void(llvm::Module &)>
2806
2807function_ref<void(llvm::Module &)>
2811
2812function_ref<void(llvm::Module &)>
2816
2818 return isaCallback;
2819}
2820
/// Returns the compilation target format (e.g. offload, assembly, binary,
/// fatbinary) these options were configured with.
CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}
2824
2826 return CompilationTarget::Fatbin;
2827}
2828
2829std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2831 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2832 llvm::StringSaver stringSaver(options.first);
2833 StringRef opts = cmdOptions;
2834 // For a correct tokenization of the command line options `opts` must be
2835 // unquoted, otherwise the tokenization function returns a single string: the
2836 // unquoted `cmdOptions` -which is not the desired behavior.
2837 // Remove any quotes if they are at the beginning and end of the string:
2838 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2839 opts.consume_front("\""), opts.consume_back("\"");
2840 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2841 opts.consume_front("'"), opts.consume_back("'");
2842#ifdef _WIN32
2843 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2844 /*MarkEOLs=*/false);
2845#else
2846 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2847 /*MarkEOLs=*/false);
2848#endif // _WIN32
2849 return options;
2850}
2851
2852std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2856
2857std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2859 size_t startPos = cmdOptions.find(startsWith);
2860 if (startPos == std::string::npos)
2861 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2862
2863 auto tokenized =
2864 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2865 cmdOptions.resize(startPos);
2866 return tokenized;
2867}
2868
2870
2871#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2872#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2873
2874#define GET_ATTRDEF_CLASSES
2875#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2876
2877#define GET_OP_CLASSES
2878#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2879
2880#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39