MLIR 23.0.0git
GPUDialect.cpp
Go to the documentation of this file.
1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
30#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
43#include <cassert>
44#include <numeric>
45#include <optional>
46
47using namespace mlir;
48using namespace mlir::gpu;
49
50#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
51
52//===----------------------------------------------------------------------===//
53// GPU Device Mapping Attributes
54//===----------------------------------------------------------------------===//
55
56int64_t GPUBlockMappingAttr::getMappingId() const {
57 return static_cast<int64_t>(getBlock());
58}
59
60bool GPUBlockMappingAttr::isLinearMapping() const {
61 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
62}
63
64int64_t GPUBlockMappingAttr::getRelativeIndex() const {
65 return isLinearMapping()
66 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
67 : getMappingId();
68}
69
70int64_t GPUWarpgroupMappingAttr::getMappingId() const {
71 return static_cast<int64_t>(getWarpgroup());
72}
73
74bool GPUWarpgroupMappingAttr::isLinearMapping() const {
75 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
76}
77
78int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
79 return isLinearMapping()
80 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
81 : getMappingId();
82}
83
84int64_t GPUWarpMappingAttr::getMappingId() const {
85 return static_cast<int64_t>(getWarp());
86}
87
88bool GPUWarpMappingAttr::isLinearMapping() const {
89 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
90}
91
92int64_t GPUWarpMappingAttr::getRelativeIndex() const {
93 return isLinearMapping()
94 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
95 : getMappingId();
96}
97
98int64_t GPUThreadMappingAttr::getMappingId() const {
99 return static_cast<int64_t>(getThread());
100}
101
102bool GPUThreadMappingAttr::isLinearMapping() const {
103 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
104}
105
106int64_t GPUThreadMappingAttr::getRelativeIndex() const {
107 return isLinearMapping()
108 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
109 : getMappingId();
110}
111
112int64_t GPULaneMappingAttr::getMappingId() const {
113 return static_cast<int64_t>(getLane());
114}
115
116bool GPULaneMappingAttr::isLinearMapping() const {
117 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
118}
119
120int64_t GPULaneMappingAttr::getRelativeIndex() const {
121 return isLinearMapping()
122 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
123 : getMappingId();
124}
125
126int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
127
128/// 8 4 0
129/// Example mask : 0 0 0 1 1 0 1 0 0
130///
131/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
132/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
133///
134/// Example mask : 0 0 0 1 1 0 1 0 0
135/// Example filter: 0 0 0 0 1 1 1 1 1
136/// Intersection : 0 0 0 0 1 0 1 0 0
137/// PopCnt : 2
138Value GPUMappingMaskAttr::createLogicalLinearMappingId(
139 OpBuilder &b, Value physicalLinearMappingId) const {
140 Location loc = physicalLinearMappingId.getLoc();
141 Value mask =
142 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
143 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
144 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
145 filter = arith::SubIOp::create(b, loc, filter, one);
146 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
147 return math::CtPopOp::create(b, loc, filteredId);
148}
149
150/// 8 4 0
151/// Example mask : 0 0 0 1 1 0 1 0 0
152///
153/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
154/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
155///
156/// Example mask : 0 0 0 1 1 0 1 0 0
157/// Example filter: 0 0 0 1 0 0 0 0 0
158/// Intersection : 0 0 0 1 0 0 0 0 0
159/// Cmp : 1
160Value GPUMappingMaskAttr::createIsActiveIdPredicate(
161 OpBuilder &b, Value physicalLinearMappingId) const {
162 Location loc = physicalLinearMappingId.getLoc();
163 Value mask =
164 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
165 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
166 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
167 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
168 Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
169 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
170 zero);
171}
172
173int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
174 return static_cast<int64_t>(getAddressSpace());
175}
176
// Memory-space mappings have no notion of linear dimensions; querying one is
// a programming error.
bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}
180
// Memory-space mappings have no relative index; querying one is a
// programming error.
int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
184
185//===----------------------------------------------------------------------===//
186// MMAMatrixType
187//===----------------------------------------------------------------------===//
188
190 StringRef operand) {
191 return Base::get(elementType.getContext(), shape, elementType, operand);
192}
193
196 ArrayRef<int64_t> shape, Type elementType,
197 StringRef operand) {
198 return Base::getChecked(emitError, elementType.getContext(), shape,
199 elementType, operand);
200}
201
202unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
203
205 return getImpl()->getShape();
206}
207
208Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
209
210StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
211
213 return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
214 elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
215 elementType.isInteger(32);
216}
217
218LogicalResult
220 ArrayRef<int64_t> shape, Type elementType,
221 StringRef operand) {
222 if (operand != "AOp" && operand != "BOp" && operand != "COp")
223 return emitError() << "operand expected to be one of AOp, BOp or COp";
224
225 if (shape.size() != 2)
226 return emitError() << "MMAMatrixType must have exactly two dimensions";
227
228 if (!MMAMatrixType::isValidElementType(elementType))
229 return emitError()
230 << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
231
232 return success();
233}
234
235//===----------------------------------------------------------------------===//
236// GPUDialect
237//===----------------------------------------------------------------------===//
238
239bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
240 if (!memorySpace)
241 return false;
242 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
243 return gpuAttr.getValue() == getWorkgroupAddressSpace();
244 return false;
245}
246
247bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
248 Attribute memorySpace = type.getMemorySpace();
249 return isWorkgroupMemoryAddressSpace(memorySpace);
250}
251
252bool GPUDialect::isConstantMemoryAddressSpace(Attribute memorySpace) {
253 if (!memorySpace)
254 return false;
255 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
256 return gpuAttr.getValue() == getConstantAddressSpace();
257 return false;
258}
259
260bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
261 Attribute memorySpace = type.getMemorySpace();
262 return isConstantMemoryAddressSpace(memorySpace);
263}
264
265bool GPUDialect::isKernel(Operation *op) {
266 if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
267 return gpuFunc.isKernel();
268 return static_cast<bool>(
269 op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
270}
271
namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  /// Returns true unconditionally: no gpu op restricts where it may be
  /// inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
284
/// Registers the dialect's types, ODS-generated ops/attributes, and
/// interface implementations (or promises thereof).
void GPUDialect::initialize() {
  // Custom (non-ODS) types.
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  // Interface implementations living in other libraries are only promised
  // here; they get attached when those libraries are loaded.
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
                            ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
                            BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
                            LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
                            SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
  declarePromisedInterfaces<memref::IndexedAccessOpInterface,
                            SubgroupMmaLoadMatrixOp,
                            SubgroupMmaStoreMatrixOp>();
}
311
312static std::string getSparseHandleKeyword(SparseHandleKind kind) {
313 switch (kind) {
315 return "sparse.dntensor_handle";
317 return "sparse.spmat_handle";
319 return "sparse.spgemmop_handle";
320 }
321 llvm_unreachable("unknown sparse handle kind");
322 return "";
323}
324
325Type GPUDialect::parseType(DialectAsmParser &parser) const {
326 // Parse the main keyword for the type.
327 StringRef keyword;
328 if (parser.parseKeyword(&keyword))
329 return Type();
330 MLIRContext *context = getContext();
331
332 // Handle 'async token' types.
333 if (keyword == "async.token")
334 return AsyncTokenType::get(context);
335
336 if (keyword == "mma_matrix") {
337 SMLoc beginLoc = parser.getNameLoc();
338
339 // Parse '<'.
340 if (parser.parseLess())
341 return nullptr;
342
343 // Parse the size and elementType.
344 SmallVector<int64_t> shape;
345 Type elementType;
346 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
347 parser.parseType(elementType))
348 return nullptr;
349
350 // Parse ','
351 if (parser.parseComma())
352 return nullptr;
353
354 // Parse operand.
355 std::string operand;
356 if (failed(parser.parseOptionalString(&operand)))
357 return nullptr;
358
359 // Parse '>'.
360 if (parser.parseGreater())
361 return nullptr;
362
364 parser.getEncodedSourceLoc(beginLoc)),
365 shape, elementType, operand);
366 }
367
369 return SparseDnTensorHandleType::get(context);
371 return SparseSpMatHandleType::get(context);
373 return SparseSpGEMMOpHandleType::get(context);
374
375 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
376 return Type();
377}
378// TODO: print refined type here. Notice that should be corresponding to the
379// parser
380void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
381 TypeSwitch<Type>(type)
382 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
383 .Case<SparseDnTensorHandleType>([&](Type) {
385 })
386 .Case<SparseSpMatHandleType>(
388 .Case<SparseSpGEMMOpHandleType>([&](Type) {
390 })
391 .Case([&](MMAMatrixType fragTy) {
392 os << "mma_matrix<";
393 auto shape = fragTy.getShape();
394 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
395 os << *dim << 'x';
396 os << shape.back() << 'x' << fragTy.getElementType();
397 os << ", \"" << fragTy.getOperand() << "\"" << '>';
398 })
399 .DefaultUnreachable("unexpected 'gpu' type kind");
400}
401
402static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
403 NamedAttribute attr) {
404 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
405 if (!array)
406 return op->emitOpError(Twine(attr.getName()) +
407 " must be a dense i32 array");
408 if (array.size() != 3)
409 return op->emitOpError(Twine(attr.getName()) +
410 " must contain exactly 3 elements");
411 return success();
412}
413
414LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
415 NamedAttribute attr) {
416 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
417 return verifyKnownLaunchSizeAttr(op, attr);
418 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
419 return verifyKnownLaunchSizeAttr(op, attr);
420 if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
421 return verifyKnownLaunchSizeAttr(op, attr);
422 if (!llvm::isa<UnitAttr>(attr.getValue()) ||
423 attr.getName() != getContainerModuleAttrName())
424 return success();
425
426 auto module = dyn_cast<ModuleOp>(op);
427 if (!module)
428 return op->emitError("expected '")
429 << getContainerModuleAttrName() << "' attribute to be attached to '"
430 << ModuleOp::getOperationName() << '\'';
431 return success();
432}
433
434/// Parses an optional list of async operands with an optional leading keyword.
435/// (`async`)? (`[` ssa-id-list `]`)?
436///
437/// This method is used by the tablegen assembly format for async ops as well.
438static ParseResult parseAsyncDependencies(
439 OpAsmParser &parser, Type &asyncTokenType,
441 auto loc = parser.getCurrentLocation();
442 if (succeeded(parser.parseOptionalKeyword("async"))) {
443 if (parser.getNumResults() == 0)
444 return parser.emitError(loc, "needs to be named when marked 'async'");
445 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
446 }
447 return parser.parseOperandList(asyncDependencies,
449}
450
451/// Prints optional async dependencies with its leading keyword.
452/// (`async`)? (`[` ssa-id-list `]`)?
453// Used by the tablegen assembly format for several async ops.
455 Type asyncTokenType,
456 OperandRange asyncDependencies) {
457 if (asyncTokenType)
458 printer << "async";
459 if (asyncDependencies.empty())
460 return;
461 if (asyncTokenType)
462 printer << ' ';
463 printer << llvm::interleaved_array(asyncDependencies);
464}
465
466// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
467/// Parses a GPU function memory attribution.
468///
469/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
470/// (`private` `(` ssa-id-and-type-list `)`)?
471///
472/// Note that this function parses only one of the two similar parts, with the
473/// keyword provided as argument.
474static ParseResult
475parseAttributions(OpAsmParser &parser, StringRef keyword,
477 // If we could not parse the keyword, just assume empty list and succeed.
478 if (failed(parser.parseOptionalKeyword(keyword)))
479 return success();
480
482 /*allowType=*/true);
483}
484
485static void printAttributions(OpAsmPrinter &p, StringRef keyword,
487 ArrayAttr attributes = {}) {
488 if (values.empty())
489 return;
490
491 p << ' ' << keyword << '(';
492 llvm::interleaveComma(
493 llvm::enumerate(values), p, [&p, attributes](auto pair) {
494 BlockArgument v = pair.value();
495 p << v << " : " << v.getType();
496
497 size_t attributionIndex = pair.index();
498 DictionaryAttr attrs;
499 if (attributes && attributionIndex < attributes.size())
500 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
501 if (attrs)
502 p.printOptionalAttrDict(attrs.getValue());
503 });
504 p << ')';
505}
506
507/// Verifies a GPU function memory attribution.
508static LogicalResult verifyAttributions(Operation *op,
509 ArrayRef<BlockArgument> attributions,
510 gpu::AddressSpace memorySpace) {
511 for (Value v : attributions) {
512 auto type = llvm::dyn_cast<MemRefType>(v.getType());
513 if (!type)
514 return op->emitOpError() << "expected memref type in attribution";
515
516 // We can only verify the address space if it hasn't already been lowered
517 // from the AddressSpaceAttr to a target-specific numeric value.
518 auto addressSpace =
519 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
520 if (!addressSpace)
521 continue;
522 if (addressSpace.getValue() != memorySpace)
523 return op->emitOpError()
524 << "expected memory space " << stringifyAddressSpace(memorySpace)
525 << " in attribution";
526 }
527 return success();
528}
529
530//===----------------------------------------------------------------------===//
531// AllReduceOp
532//===----------------------------------------------------------------------===//
533
534static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
535 Type resType) {
536 using Kind = gpu::AllReduceOperation;
537 if (llvm::is_contained(
538 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
539 opName)) {
540 if (!isa<FloatType>(resType))
541 return failure();
542 }
543
544 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
545 Kind::AND, Kind::OR, Kind::XOR},
546 opName)) {
547 if (!isa<IntegerType>(resType))
548 return failure();
549 }
550
551 return success();
552}
553
/// Verifies the reduction specification: exactly one of the `op` attribute
/// or a custom reduction region must be present, and whichever is present
/// must be consistent with the result type.
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    // Region form: the body combines two partial values of the result type.
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    // Each gpu.yield must produce one value of the result type; at least one
    // yield must exist (blocks may also end in branches, which are skipped).
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    // Attribute form: the named reduction must be compatible with the type.
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
587
589 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
590 if (!launchOp)
591 return false;
592
593 Region &body = launchOp.getBody();
594 assert(!body.empty() && "Invalid region");
595
596 // Only convert ops in gpu::launch entry block for now.
597 return op->getBlock() == &body.front();
598}
599
600OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
601 if (!getUniform() && canMakeGroupOpUniform(*this)) {
602 setUniform(true);
603 return getResult();
604 }
605
606 return nullptr;
607}
608
609// TODO: Support optional custom attributes (without dialect prefix).
610static ParseResult parseAllReduceOperation(AsmParser &parser,
611 AllReduceOperationAttr &attr) {
612 StringRef enumStr;
613 if (!parser.parseOptionalKeyword(&enumStr)) {
614 std::optional<AllReduceOperation> op =
615 gpu::symbolizeAllReduceOperation(enumStr);
616 if (!op)
617 return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
618 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
619 }
620 return success();
621}
622
624 AllReduceOperationAttr attr) {
625 if (attr)
626 attr.print(printer);
627}
628
629//===----------------------------------------------------------------------===//
630// SubgroupReduceOp
631//===----------------------------------------------------------------------===//
632
633LogicalResult gpu::SubgroupReduceOp::verify() {
634 Type elemType = getType();
635 if (auto vecTy = dyn_cast<VectorType>(elemType)) {
636 if (vecTy.isScalable())
637 return emitOpError() << "is not compatible with scalable vector types";
638
639 elemType = vecTy.getElementType();
640 }
641
642 gpu::AllReduceOperation opName = getOp();
643 if (failed(verifyReduceOpAndType(opName, elemType))) {
644 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
645 << "` reduction operation is not compatible with type "
646 << getType();
647 }
648
649 auto clusterSize = getClusterSize();
650 if (clusterSize) {
651 uint32_t size = *clusterSize;
652 if (!llvm::isPowerOf2_32(size)) {
653 return emitOpError() << "cluster size " << size
654 << " is not a power of two";
655 }
656 }
657
658 uint32_t stride = getClusterStride();
659 if (stride != 1 && !clusterSize) {
660 return emitOpError() << "cluster stride can only be specified if cluster "
661 "size is specified";
662 }
663 if (!llvm::isPowerOf2_32(stride)) {
664 return emitOpError() << "cluster stride " << stride
665 << " is not a power of two";
666 }
667
668 return success();
669}
670
671OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
672 if (getClusterSize() == 1)
673 return getValue();
674
675 if (!getUniform() && canMakeGroupOpUniform(*this)) {
676 setUniform(true);
677 return getResult();
678 }
679
680 return nullptr;
681}
682
683//===----------------------------------------------------------------------===//
684// AsyncOpInterface
685//===----------------------------------------------------------------------===//
686
688 op->insertOperands(0, {token});
689 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
690 return;
691 auto attrName =
693 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
694
695 // Async dependencies is the only variadic operand.
696 if (!sizeAttr)
697 return;
698
699 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
700 ++sizes.front();
701 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
702}
703
704//===----------------------------------------------------------------------===//
705// LaunchOp
706//===----------------------------------------------------------------------===//
707
/// Builds a gpu.launch op: operands are the async dependencies, grid/block
/// sizes, optional cluster sizes and optional dynamic shared memory size;
/// the body region receives kNumConfigRegionAttributes index arguments
/// followed by the workgroup and private attributions.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Record how many region arguments are workgroup attributions so they can
  // be told apart from private attributions later.
  if (!workgroupAttributions.empty())
    result.addAttribute(
        getWorkgroupAttributionsAttrName(result.name),
        builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute. Segments: async deps, grid x/y/z,
  // block x/y/z, cluster x/y/z (optional), dynamic shared memory (optional).
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
770
771KernelDim3 LaunchOp::getBlockIds() {
772 assert(!getBody().empty() && "LaunchOp body must not be empty.");
773 auto args = getBody().getArguments();
774 return KernelDim3{args[0], args[1], args[2]};
775}
776
777KernelDim3 LaunchOp::getThreadIds() {
778 assert(!getBody().empty() && "LaunchOp body must not be empty.");
779 auto args = getBody().getArguments();
780 return KernelDim3{args[3], args[4], args[5]};
781}
782
783KernelDim3 LaunchOp::getGridSize() {
784 assert(!getBody().empty() && "LaunchOp body must not be empty.");
785 auto args = getBody().getArguments();
786 return KernelDim3{args[6], args[7], args[8]};
787}
788
789KernelDim3 LaunchOp::getBlockSize() {
790 assert(!getBody().empty() && "LaunchOp body must not be empty.");
791 auto args = getBody().getArguments();
792 return KernelDim3{args[9], args[10], args[11]};
793}
794
795std::optional<KernelDim3> LaunchOp::getClusterIds() {
796 assert(!getBody().empty() && "LaunchOp body must not be empty.");
797 if (!hasClusterSize())
798 return std::nullopt;
799 auto args = getBody().getArguments();
800 return KernelDim3{args[12], args[13], args[14]};
801}
802
803std::optional<KernelDim3> LaunchOp::getClusterSize() {
804 assert(!getBody().empty() && "LaunchOp body must not be empty.");
805 if (!hasClusterSize())
806 return std::nullopt;
807 auto args = getBody().getArguments();
808 return KernelDim3{args[15], args[16], args[17]};
809}
810
811KernelDim3 LaunchOp::getGridSizeOperandValues() {
812 auto operands = getOperands().drop_front(getAsyncDependencies().size());
813 return KernelDim3{operands[0], operands[1], operands[2]};
814}
815
816KernelDim3 LaunchOp::getBlockSizeOperandValues() {
817 auto operands = getOperands().drop_front(getAsyncDependencies().size());
818 return KernelDim3{operands[3], operands[4], operands[5]};
819}
820
821std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
822 auto operands = getOperands().drop_front(getAsyncDependencies().size());
823 if (!hasClusterSize())
824 return std::nullopt;
825 return KernelDim3{operands[6], operands[7], operands[8]};
826}
827
828LogicalResult LaunchOp::verify() {
829 if (!(hasClusterSize()) &&
830 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
831 return emitOpError() << "cluster size must be all present";
832 return success();
833}
834
/// Verifies the launch body region: argument count, attribution address
/// spaces, terminators, and the async-result naming constraint.
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (getBody().empty()) {
    return emitOpError("body region is empty");
  }
  if (getBody().getNumArguments() <
      kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
    return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  // An async launch produces a token result that must be bound to a name.
  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
876
877// Pretty-print the kernel grid/block size assignment as
878// (%iter-x, %iter-y, %iter-z) in
879// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
880// where %size-* and %iter-* will correspond to the body region arguments.
882 KernelDim3 operands, KernelDim3 ids) {
883 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
884 p << size.x << " = " << operands.x << ", ";
885 p << size.y << " = " << operands.y << ", ";
886 p << size.z << " = " << operands.z << ')';
887}
888
/// Prints the launch op custom assembly: optional async token/dependencies,
/// cluster/block/thread size assignments, dynamic shared memory size,
/// optional module/function symbols, cooperative flag, attributions, body
/// region and remaining attributes. Must mirror LaunchOp::parse.
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  if (getCooperative())
    p << " cooperative";

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  // Attributes rendered inline above are elided from the trailing dict.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getWorkgroupAttributionsAttrName(),
                              getCooperativeAttrName(), moduleAttrName,
                              functionAttrName});
}
942
943// Parse the size assignment blocks for blocks and threads. These have the form
944// (%region_arg, %region_arg, %region_arg) in
945// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
946// where %region_arg are percent-identifiers for the region arguments to be
947// introduced further (SSA defs), and %operand are percent-identifiers for the
948// SSA value uses.
949static ParseResult
// NOTE(review): the parameter list is not visible in this extract; `args`,
// `sizes`, `regionSizes` and `indices` used below are parameters of this
// function — confirm against the full source.
 954                    StringRef keyword) {
 955  assert(indices.size() == 3 && "space for three indices expected");
// The parse of the leading `(%id, %id, %id)` list and the `in (` opener is on
// the lines elided from this extract.
 958                                   /*allowResultNumber=*/false) ||
 959      parser.parseKeyword("in") || parser.parseLParen())
 960    return failure();
 961
// Exactly three identifiers are required for the x/y/z dimensions.
 962  if (args.size() != 3) {
 963    return parser.emitError(parser.getNameLoc())
 964           << keyword << " expects 3 arguments, but got " << args.size();
 965  }
 966  std::move(args.begin(), args.end(), indices.begin());
 967
// Parse the `%region_arg = %operand` triple inside the parentheses.
 968  for (int i = 0; i < 3; ++i) {
 969    if (i != 0 && parser.parseComma())
 970      return failure();
 971    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
 972        parser.parseEqual() || parser.parseOperand(sizes[i]))
 973      return failure();
 974  }
 975
 976  return parser.parseRParen();
 977}
978
979/// Parses a Launch operation.
980/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
981/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
982/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
983/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
984/// (`dynamic_shared_memory_size` ssa-use)?
985/// (`module(` symbol-ref-id `)`)?
986/// (`function(` symbol-ref-id `)`)?
987/// memory-attribution
988/// region attr-dict?
989/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
990ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
991  // Sizes of the grid and block.
992  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
993      sizes(LaunchOp::kNumConfigOperands);
994
995  // Region arguments to be created.
996  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
997      LaunchOp::kNumConfigRegionAttributes);
998
999  // Parse optional async dependencies.
1000  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
1001  Type asyncTokenType;
1002  if (failed(
1003          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
1004      parser.resolveOperands(asyncDependencies, asyncTokenType,
1005                             result.operands))
1006    return failure();
// A result is only legal when the async form was used; its type is the token.
1007  if (parser.getNumResults() > 0) {
1008    if (!asyncTokenType)
1009      return parser.emitError(
1010          parser.getNameLoc(),
1011          "gpu.launch requires 'async' keyword to return a value");
1012    result.types.push_back(asyncTokenType);
1013  }
1014
// With clusters present there are 9 size operands (cluster/grid/block x,y,z)
// and 18 region arguments (ids + sizes for each of the three levels).
1015  bool hasCluster = false;
1016  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
1017    hasCluster = true;
1018    sizes.resize(9);
1019    regionArgs.resize(18);
1020  }
1021  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1022  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1023
1024  // Last three segment assigns the cluster size. In the region argument
1025  // list, this is last 6 arguments.
1026  if (hasCluster) {
// NOTE(review): the `if (failed(parseSizeAssignment(` opener on the next
// original line is elided from this extract.
1028            parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1029            regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1030      return failure();
1031  }
1032  // Parse the size assignment segments: the first segment assigns grid sizes
1033  // and defines values for block identifiers; the second segment assigns block
1034  // sizes and defines values for thread identifiers. In the region argument
1035  // list, identifiers precede sizes, and block-related values precede
1036  // thread-related values.
1037  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
1038      parseSizeAssignment(parser, sizesRef.take_front(3),
1039                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1040                          LaunchOp::getBlocksKeyword()) ||
1041      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
1042      parseSizeAssignment(parser, sizesRef.drop_front(3),
1043                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1044                          LaunchOp::getThreadsKeyword()) ||
1045      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
1046                             result.operands))
1047    return failure();
1048
// Optional dynamic shared memory size, an i32 operand.
1049  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1050  bool hasDynamicSharedMemorySize = false;
1051  if (!parser.parseOptionalKeyword(
1052          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1053    hasDynamicSharedMemorySize = true;
1054    if (parser.parseOperand(dynamicSharedMemorySize) ||
1055        parser.resolveOperand(dynamicSharedMemorySize,
1056                              parser.getBuilder().getI32Type(),
1057                              result.operands))
1058      return failure();
1059  }
1060
1061  // Parse optional module attribute.
1062  StringRef moduleAttrName = getModuleAttrName(result.name);
1063  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1064    FlatSymbolRefAttr moduleSymbol;
1065    if (parser.parseLParen() ||
1066        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1067                              result.attributes) ||
1068        parser.parseRParen())
1069      return failure();
1070  }
1071  // Parse optional function attribute.
1072  StringRef functionAttrName = getFunctionAttrName(result.name);
1073  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1074    FlatSymbolRefAttr funcSymbol;
1075    if (parser.parseLParen() ||
1076        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1077                              result.attributes) ||
1078        parser.parseRParen())
1079      return failure();
1080  }
1081
1082  // Parse optional cooperative keyword.
1083  if (succeeded(parser.parseOptionalKeyword("cooperative")))
1084    result.addAttribute("cooperative", parser.getBuilder().getUnitAttr());
1085
1086  // Create the region arguments: fixed launch-config args (`index`), then
1087  // workgroup / private attribution args. The workgroup count is stored in the
1088  // inherent `workgroup_attributions` attribute when non-zero.
1089  Type index = parser.getBuilder().getIndexType();
1090  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1091      LaunchOp::kNumConfigRegionAttributes + 6, index);
1092
1093  SmallVector<OpAsmParser::Argument> regionArguments;
1094  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1095    OpAsmParser::Argument arg;
1096    arg.ssaName = std::get<0>(ssaValueAndType);
1097    arg.type = std::get<1>(ssaValueAndType);
1098    regionArguments.push_back(arg);
1099  }
1100
1101  Builder &builder = parser.getBuilder();
1102  // Parse workgroup memory attributions.
1103  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1104                               regionArguments)))
1105    return failure();
1106
1107  // Store the number of operands we just parsed as the number of workgroup
1108  // memory attributions.
1109  unsigned numWorkgroupAttrs = regionArguments.size() -
1110                               LaunchOp::kNumConfigRegionAttributes -
1111                               (hasCluster ? 6 : 0);
1112  if (numWorkgroupAttrs != 0)
1113    result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
1114                        builder.getI64IntegerAttr(numWorkgroupAttrs));
1115
1116  // Parse private memory attributions.
1117  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1118                               regionArguments)))
1119    return failure();
1120
1121  // Introduce the body region and parse it. The region has
1122  // kNumConfigRegionAttributes arguments that correspond to
1123  // block/thread identifiers and grid/block sizes, all having `index` type.
1124  Region *body = result.addRegion();
1125  if (parser.parseRegion(*body, regionArguments) ||
1126      parser.parseOptionalAttrDict(result.attributes))
1127    return failure();
1128
// Build the operand segment sizes: [async deps, grid x/y/z, block x/y/z,
// cluster x/y/z (0 when absent), dynamic shared memory size].
1129  SmallVector<int32_t, 11> segmentSizes(11, 1);
1130  segmentSizes.front() = asyncDependencies.size();
1131
1132  if (!hasCluster) {
1133    segmentSizes[7] = 0;
1134    segmentSizes[8] = 0;
1135    segmentSizes[9] = 0;
1136  }
1137  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1138  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1139                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1140  return success();
1141}
1142
1143/// Simplify the gpu.launch when the range of a thread or block ID is
1144/// trivially known to be one.
1145struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1146 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1147 LogicalResult matchAndRewrite(LaunchOp op,
1148 PatternRewriter &rewriter) const override {
1149 // If the range implies a single value for `id`, replace `id`'s uses by
1150 // zero.
1151 Value zero;
1152 bool simplified = false;
1153 auto constPropIdUses = [&](Value id, Value size) {
1154 // Check if size is trivially one.
1155 if (!matchPattern(size, m_One()))
1156 return;
1157 if (id.getUses().empty())
1158 return;
1159 if (!simplified) {
1160 // Create a zero value the first time.
1161 OpBuilder::InsertionGuard guard(rewriter);
1162 rewriter.setInsertionPointToStart(&op.getBody().front());
1163 zero =
1164 arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
1165 }
1166 rewriter.replaceAllUsesWith(id, zero);
1167 simplified = true;
1168 };
1169 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1170 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1171 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1172 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1173 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1174 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1175
1176 return success(simplified);
1177 }
1178};
1179
/// Registers the launch-specific canonicalization patterns.
1180void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1181                                           MLIRContext *context) {
1182  rewrites.add<FoldLaunchArguments>(context);
1183}
1184
1185/// Adds a new block argument that corresponds to buffers located in
1186/// workgroup memory.
1187BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1188 int64_t cur = getWorkgroupAttributions().value_or(0);
1189 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1190 return getBody().insertArgument(
1191 getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
1192}
1193
1194/// Adds a new block argument that corresponds to buffers located in
1195/// private memory.
1196BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1197 // Buffers on the private memory always come after buffers on the workgroup
1198 // memory.
1199 return getBody().addArgument(type, loc);
1200}
1201
1202//===----------------------------------------------------------------------===//
1203// LaunchFuncOp
1204//===----------------------------------------------------------------------===//
1205
/// Builds a launch_func targeting the kernel identified by `kernelSymbol`
/// (a @gpu_module::@kernel nested reference). Operand order: async deps,
/// grid x/y/z, block x/y/z, optional cluster x/y/z, optional dynamic shared
/// memory size, then the kernel operands.
1206void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1207                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1208                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1209                         ValueRange kernelOperands, Type asyncTokenType,
1210                         ValueRange asyncDependencies,
1211                         std::optional<KernelDim3> clusterSize) {
1212  assert(kernelSymbol.getNestedReferences().size() == 1 &&
1213         "expected a symbol reference with a single nested reference");
1214  result.addOperands(asyncDependencies);
1215  if (asyncTokenType)
1216    result.types.push_back(builder.getType<AsyncTokenType>());
1217
1218  // Add grid and block sizes as op operands, followed by the data operands.
1219  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
// NOTE(review): the continuation line adding getBlockSize.y/z is elided from
// this extract.
1221  if (clusterSize.has_value())
1222    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1223  if (dynamicSharedMemorySize)
1224    result.addOperands(dynamicSharedMemorySize);
1225  result.addOperands(kernelOperands);
1226
// Fill in the operand segment sizes; segments default to 1 and the variadic
// or optional segments are overwritten below.
1227  Properties &prop = result.getOrAddProperties<Properties>();
1228  prop.kernel = kernelSymbol;
1229  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1230  // Initialize the segment sizes to 1.
1231  llvm::fill(prop.operandSegmentSizes, 1);
1232  prop.operandSegmentSizes[0] = asyncDependencies.size();
1233  if (!clusterSize.has_value()) {
1234    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1235    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1236    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1237  }
1238  prop.operandSegmentSizes[segmentSizesLen - 3] =
1239      dynamicSharedMemorySize ? 1 : 0;
1240  prop.operandSegmentSizes[segmentSizesLen - 2] =
1241      static_cast<int32_t>(kernelOperands.size());
// This overload never takes an explicit async object operand.
1242  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1243}
1244
/// Convenience builder: forms the nested symbol @gpu_module::@kernel from the
/// given GPUFuncOp and delegates to the symbol-based builder above.
1245void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1246                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
1247                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1248                         ValueRange kernelOperands, Type asyncTokenType,
1249                         ValueRange asyncDependencies,
1250                         std::optional<KernelDim3> clusterSize) {
1251  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1252  auto kernelSymbol =
1253      SymbolRefAttr::get(kernelModule.getNameAttr(),
1254                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1255  build(builder, result, kernelSymbol, gridSize, getBlockSize,
1256        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1257        asyncDependencies, clusterSize);
1258}
1259
/// Builds a launch_func that uses an explicit `asyncObject` operand (e.g. a
/// stream) instead of async token dependencies. Operand layout matches the
/// token-based builder, with the async object appended last.
1260void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1261                         SymbolRefAttr kernel, KernelDim3 gridSize,
1262                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1263                         ValueRange kernelOperands, Value asyncObject,
1264                         std::optional<KernelDim3> clusterSize) {
1265  // Add grid and block sizes as op operands, followed by the data operands.
1266  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
// NOTE(review): the continuation line adding getBlockSize.y/z is elided from
// this extract.
1268  if (clusterSize.has_value())
1269    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1270  if (dynamicSharedMemorySize)
1271    result.addOperands(dynamicSharedMemorySize);
1272  result.addOperands(kernelOperands);
1273  if (asyncObject)
1274    result.addOperands(asyncObject);
1275  Properties &prop = result.getOrAddProperties<Properties>();
1276  prop.kernel = kernel;
1277  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1278  // Initialize the segment sizes to 1.
1279  llvm::fill(prop.operandSegmentSizes, 1);
// No async token dependencies in this form.
1280  prop.operandSegmentSizes[0] = 0;
1281  if (!clusterSize.has_value()) {
1282    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1283    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1284    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1285  }
1286  prop.operandSegmentSizes[segmentSizesLen - 3] =
1287      dynamicSharedMemorySize ? 1 : 0;
1288  prop.operandSegmentSizes[segmentSizesLen - 2] =
1289      static_cast<int32_t>(kernelOperands.size());
1290  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1291}
1292
/// Returns the name of the GPU module containing the kernel (the root of the
/// nested kernel symbol reference).
1293StringAttr LaunchFuncOp::getKernelModuleName() {
1294  return getKernel().getRootReference();
1295}
1296
/// Returns the kernel function's name (the leaf of the nested kernel symbol).
1297StringAttr LaunchFuncOp::getKernelName() {
1298  return getKernel().getLeafReference();
1299}
1300
/// Returns the number of operands passed to the kernel itself (excludes launch
/// configuration operands).
1301unsigned LaunchFuncOp::getNumKernelOperands() {
1302  return getKernelOperands().size();
1303}
1304
/// Returns the i-th kernel operand.
1305Value LaunchFuncOp::getKernelOperand(unsigned i) {
1306  return getKernelOperands()[i];
1307}
1308
/// Returns the grid size operands (x, y, z): the first three operands after
/// the async dependencies.
1309KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1310  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1311  return KernelDim3{operands[0], operands[1], operands[2]};
1312}
1313
/// Returns the block size operands (x, y, z): operands 3..5 after the async
/// dependencies.
1314KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1315  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1316  return KernelDim3{operands[3], operands[4], operands[5]};
1317}
1318
/// Returns the cluster size operands (x, y, z): operands 6..8 after the async
/// dependencies. Only valid when hasClusterSize() is true.
1319KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1320  assert(hasClusterSize() &&
1321         "cluster size is not set, check hasClusterSize() first");
1322  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1323  return KernelDim3{operands[6], operands[7], operands[8]};
1324}
1325
1326LogicalResult LaunchFuncOp::verify() {
1327 auto module = (*this)->getParentOfType<ModuleOp>();
1328 if (!module)
1329 return emitOpError("expected to belong to a module");
1330
1331 if (!module->getAttrOfType<UnitAttr>(
1332 GPUDialect::getContainerModuleAttrName()))
1333 return emitOpError("expected the closest surrounding module to have the '" +
1334 GPUDialect::getContainerModuleAttrName() +
1335 "' attribute");
1336
1337 if (hasClusterSize()) {
1338 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1339 getClusterSizeZ().getType() != getClusterSizeX().getType())
1340 return emitOpError()
1341 << "expects types of the cluster dimensions must be the same";
1342 }
1343
1344 if (!getAsyncDependencies().empty() && getAsyncObject())
1345 return emitOpError(
1346 "cannot have both async dependencies and an explicit async object");
1347
1348 return success();
1349}
1350
/// Verifies that the kernel symbol referenced by this launch_func resolves to
/// a well-formed kernel inside a gpu.module (or gpu.binary), and that the
/// operand types match the kernel's signature when it is a gpu.func.
1351LogicalResult
1352LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1353  LaunchFuncOp launchOp = *this;
1354  Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
1355  // GPU modules cannot be nested within each other, escape to resolve the name.
1356  if (isa<GPUModuleOp>(table))
// NOTE(review): the statement executed for this branch (escaping to the outer
// symbol table) is on a line elided from this extract.
1358
1359  // Ignore launches that are nested more or less deep than functions in the
1360  // module we are currently checking.
1361  if (!launchOp->getParentOp() ||
1362      launchOp->getParentOp()->getParentOp() != table)
1363    return success();
1364
1365  // Ignore launch ops with missing attributes here. The errors will be
1366  // reported by the verifiers of those ops.
1367  if (!launchOp->getAttrOfType<SymbolRefAttr>(
1368          LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1369    return success();
1370
1371  // Check that `launch_func` refers to a well-formed GPU kernel container.
1372  StringAttr kernelContainerName = launchOp.getKernelModuleName();
1373  Operation *kernelContainer =
1374      symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
1375  if (!kernelContainer)
1376    return launchOp.emitOpError()
1377           << "kernel container '" << kernelContainerName.getValue()
1378           << "' is undefined";
1379
1380  // If the container is a GPU binary op return success.
1381  if (isa<BinaryOp>(kernelContainer))
1382    return success();
1383
1384  auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1385  if (!kernelModule)
1386    return launchOp.emitOpError()
1387           << "kernel module '" << kernelContainerName.getValue()
1388           << "' is undefined";
1389
1390  // Check that `launch_func` refers to a well-formed kernel function.
1391  Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
1392      kernelModule, launchOp.getKernelName());
1393  if (!kernelFunc)
1394    return launchOp.emitOpError("kernel function '")
1395           << launchOp.getKernel() << "' is undefined";
1396  auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1397  if (!kernelConvertedFunction) {
1398    InFlightDiagnostic diag = launchOp.emitOpError()
1399                              << "referenced kernel '" << launchOp.getKernel()
1400                              << "' is not a function";
1401    diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
1402    return diag;
1403  }
1404
1405  if (!GPUDialect::isKernel(kernelFunc))
1406    return launchOp.emitOpError("kernel function is missing the '")
1407           << GPUDialect::getKernelFuncAttrName() << "' attribute";
1408
1409  // TODO: If the kernel isn't a GPU function (which happens during separate
1410  // compilation), do not check type correspondence as it would require the
1411  // verifier to be aware of the type conversion.
1412  auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1413  if (!kernelGPUFunction)
1414    return success();
1415
// Check operand count and element-wise type agreement against the kernel.
1416  unsigned actualNumArguments = launchOp.getNumKernelOperands();
1417  unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1418  if (expectedNumArguments != actualNumArguments)
1419    return launchOp.emitOpError("got ")
1420           << actualNumArguments << " kernel operands but expected "
1421           << expectedNumArguments;
1422
1423  FunctionType functionType = kernelGPUFunction.getFunctionType();
1424  for (unsigned i = 0; i < expectedNumArguments; ++i) {
1425    if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1426      return launchOp.emitOpError("type of function argument ")
1427             << i << " does not match";
1428    }
1429  }
1430
1431  return success();
1432}
1433
/// Parses an optional `: type` suffix for a launch dimension; when absent the
/// dimension defaults to `index`. If a cluster operand is present, the cluster
/// x/y/z types are set to the same type as the dimension.
1434static ParseResult
// NOTE(review): the first parameter line (parser and the out-param `dimTy`) is
// elided from this extract — confirm against the full source.
1436                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1437                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1438  if (succeeded(parser.parseOptionalColon())) {
1439    if (parser.parseType(dimTy))
1440      return failure();
1441  } else {
1442    dimTy = IndexType::get(parser.getContext());
1443  }
1444  if (clusterValue.has_value()) {
1445    clusterXTy = clusterYTy = clusterZTy = dimTy;
1446  }
1447  return success();
1448}
1449
1450static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1451 Value clusterValue, Type clusterXTy,
1452 Type clusterYTy, Type clusterZTy) {
1453 if (!dimTy.isIndex())
1454 printer << ": " << dimTy;
1455}
1456
/// Parses the optional `args(%op : type, ...)` clause of a launch_func.
/// Succeeds trivially (leaving the vectors empty) when the `args` keyword is
/// absent.
1457static ParseResult parseLaunchFuncOperands(
1458    OpAsmParser &parser,
// NOTE(review): the `argNames` output-vector parameter line is elided from
// this extract.
1460    SmallVectorImpl<Type> &argTypes) {
1461  if (parser.parseOptionalKeyword("args"))
1462    return success();
1463
// Each element has the form `%operand : type`.
1464  auto parseElement = [&]() -> ParseResult {
1465    return failure(parser.parseOperand(argNames.emplace_back()) ||
1466                   parser.parseColonType(argTypes.emplace_back()));
1467  };
1468
// NOTE(review): the parseCommaSeparatedList call opening line is elided from
// this extract.
1470                                        parseElement, " in argument list");
1471}
1472
// NOTE(review): the signature's opening line (printer and op parameters) is
// elided from this extract. Prints the `args(%op : type, ...)` clause; prints
// nothing when there are no kernel operands.
1474                                    OperandRange operands, TypeRange types) {
1475  if (operands.empty())
1476    return;
1477  printer << "args(";
1478  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1479                        [&](const auto &pair) {
1480                          auto [operand, type] = pair;
1481                          printer << operand << " : " << type;
1482                        });
1483  printer << ")";
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// ShuffleOp
1488//===----------------------------------------------------------------------===//
1489
1490void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1491 int32_t offset, int32_t width, ShuffleMode mode) {
1492 build(builder, result, value,
1493 arith::ConstantOp::create(builder, result.location,
1494 builder.getI32IntegerAttr(offset)),
1495 arith::ConstantOp::create(builder, result.location,
1496 builder.getI32IntegerAttr(width)),
1497 mode);
1498}
1499
1500//===----------------------------------------------------------------------===//
1501// RotateOp
1502//===----------------------------------------------------------------------===//
1503
1504LogicalResult RotateOp::verify() {
1505 uint32_t offset = getOffset();
1506 uint32_t width = getWidth();
1507
1508 if (offset >= width) {
1509 return emitOpError() << "offset must be in the range [0, " << width << ")";
1510 }
1511
1512 return success();
1513}
1514
1515//===----------------------------------------------------------------------===//
1516// BarrierOp
1517//===----------------------------------------------------------------------===//
1518
1519/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1520static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1521 PatternRewriter &rewriter) {
1522 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1523 if (!nextOp)
1524 return failure();
1525
1526 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1527 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1528
1529 if (thisMemfence) {
1530 rewriter.modifyOpInPlace(op, [&]() {
1531 if (!nextMemfence) {
1532 op.removeAddressSpacesAttr();
1533 return;
1534 }
1535 // Fast path - merge where the two barriers fence the same spaces.
1536 if (*thisMemfence == *nextMemfence) {
1537 return;
1538 }
1539
1540 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1541 for (Attribute attr : *thisMemfence)
1542 mergedSpaces.insert(attr);
1543 for (Attribute attr : *nextMemfence)
1544 mergedSpaces.insert(attr);
1545 op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
1546 });
1547 }
1548
1549 rewriter.eraseOp(nextOp);
1550 return success();
1551}
1552
/// Registers the barrier canonicalization patterns.
1553void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1554                                            MLIRContext *context) {
// NOTE(review): the registration statement (presumably adding
// eraseRedundantGpuBarrierOps) is on a line elided from this extract.
1556}
1557
1558void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1559 mlir::OperationState &odsState,
1560 std::optional<AddressSpace> addressSpace) {
1561 ArrayAttr addressSpacesAttr;
1562 if (addressSpace)
1563 addressSpacesAttr = odsBuilder.getArrayAttr(
1564 AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
1565 build(odsBuilder, odsState, addressSpacesAttr);
1566}
1567
1568/// Builds a barrier that causes memory operations affecting `memrefToFence` to
1569/// be completed after the barrier is concluded. Currently, this means setting
1570/// the fenced address spaces to those of the given memref if it is a gpu
1571/// address space.
1572void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1573 Value memrefToFence) {
1574 std::optional<AddressSpace> addrSpaceToFence;
1575 if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
1576 if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1577 memrefType.getMemorySpace()))
1578 addrSpaceToFence = addrSpaceAttr.getValue();
1579 return build(builder, odsState, addrSpaceToFence);
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// GPUFuncOp
1584//===----------------------------------------------------------------------===//
1585
1586/// Adds a new block argument that corresponds to buffers located in
1587/// workgroup memory.
1588BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1589 int64_t cur = getWorkgroupAttributions().value_or(0);
1590 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1591 return getBody().insertArgument(
1592 getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
1593}
1594
1595/// Adds a new block argument that corresponds to buffers located in
1596/// private memory.
1597BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1598 // Buffers on the private memory always come after buffers on the workgroup
1599 // memory.
1600 return getBody().addArgument(type, loc);
1601}
1602
/// Builds a gpu.func with the given name, signature and memory attributions,
/// creating the entry block with one argument per input and per attribution.
1603void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1604                      StringRef name, FunctionType type,
1605                      TypeRange workgroupAttributions,
1606                      TypeRange privateAttributions,
1607                      ArrayRef<NamedAttribute> attrs) {
1608  OpBuilder::InsertionGuard g(builder);
1609
// NOTE(review): the line adding the symbol-name attribute opener is elided
// from this extract; the next line is its continuation.
1611                      builder.getStringAttr(name));
1612  result.addAttribute(getFunctionTypeAttrName(result.name),
1613                      TypeAttr::get(type));
1614  result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
1615                      builder.getI64IntegerAttr(workgroupAttributions.size()));
1616  result.addAttributes(attrs);
1617  Region *body = result.addRegion();
1618  Block *entryBlock = builder.createBlock(body);
1619
1620  // TODO: Allow passing in proper locations here.
// Entry-block arguments: function inputs first, then workgroup attributions,
// then private attributions.
1621  for (Type argTy : type.getInputs())
1622    entryBlock->addArgument(argTy, result.location);
1623  for (Type argTy : workgroupAttributions)
1624    entryBlock->addArgument(argTy, result.location);
1625  for (Type argTy : privateAttributions)
1626    entryBlock->addArgument(argTy, result.location);
1627}
1628
1629/// Parses a GPU function memory attribution.
1630///
1631/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1632///                        (`private` `(` ssa-id-and-type-list `)`)?
1633///
1634/// Note that this function parses only one of the two similar parts, with the
1635/// keyword provided as argument.
1636static ParseResult
1637parseAttributions(OpAsmParser &parser, StringRef keyword,
// NOTE(review): the `args` argument-list parameter line is elided from this
// extract.
1639                  Attribute &attributionAttrs) {
1640  // If we could not parse the keyword, just assume empty list and succeed.
1641  if (failed(parser.parseOptionalKeyword(keyword)))
1642    return success();
1643
1644  size_t existingArgs = args.size();
1645  ParseResult result =
// NOTE(review): the argument-list parse call opener is elided from this
// extract; the next line carries its trailing parameters.
1647                                /*allowType=*/true, /*allowAttrs=*/true);
1648  if (failed(result))
1649    return result;
1650
// Only materialize an attrs array if at least one attribution carried an
// attribute dictionary; otherwise leave the attribute null.
1651  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1652                               [](const OpAsmParser::Argument &arg) -> bool {
1653                                 return arg.attrs && !arg.attrs.empty();
1654                               });
1655  if (!hadAttrs) {
1656    attributionAttrs = nullptr;
1657    return result;
1658  }
1659
1660  Builder &builder = parser.getBuilder();
1661  SmallVector<Attribute> attributionAttrsVec;
1662  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1663    if (!argument.attrs)
1664      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1665    else
1666      attributionAttrsVec.push_back(argument.attrs);
1667  }
1668  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1669  return result;
1670}
1671
1672/// Parses a GPU function.
1673///
1674/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1675///                 (`->` function-result-list)? memory-attribution `kernel`?
1676///                 function-attributes? region
1677ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1678  SmallVector<OpAsmParser::Argument> entryArgs;
1679  SmallVector<DictionaryAttr> resultAttrs;
1680  SmallVector<Type> resultTypes;
1681  bool isVariadic;
1682
1683  // Parse the function name.
1684  StringAttr nameAttr;
// NOTE(review): the symbol-name parse call opener is elided from this extract;
// the next line carries its trailing arguments.
1686                             result.attributes))
1687    return failure();
1688
1689  auto signatureLocation = parser.getCurrentLocation();
// NOTE(review): the function-signature parse call opener is elided from this
// extract; the next two lines carry its arguments.
1691          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1692          resultAttrs)))
1693    return failure();
1694
1695  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1696    return parser.emitError(signatureLocation)
1697           << "gpu.func requires named arguments";
1698
1699  // Construct the function type. More types will be added to the region, but
1700  // not to the function type.
1701  Builder &builder = parser.getBuilder();
1702
1703  SmallVector<Type> argTypes;
1704  for (auto &arg : entryArgs)
1705    argTypes.push_back(arg.type);
1706  auto type = builder.getFunctionType(argTypes, resultTypes);
1707  result.addAttribute(getFunctionTypeAttrName(result.name),
1708                      TypeAttr::get(type));
1709
// NOTE(review): the call storing arg/result attributes is opened on a line
// elided from this extract; the next two lines carry its arguments.
1711      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1712      getResAttrsAttrName(result.name));
1713
1714  Attribute workgroupAttributionAttrs;
1715  // Parse workgroup memory attributions.
1716  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1717                               entryArgs, workgroupAttributionAttrs)))
1718    return failure();
1719
1720  // Store the number of operands we just parsed as the number of workgroup
1721  // memory attributions.
1722  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1723  if (numWorkgroupAttrs != 0)
1724    result.addAttribute(
1725        GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
1726        builder.getI64IntegerAttr(numWorkgroupAttrs));
1727  if (workgroupAttributionAttrs)
1728    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1729                        workgroupAttributionAttrs);
1730
1731  Attribute privateAttributionAttrs;
1732  // Parse private memory attributions.
1733  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1734                               entryArgs, privateAttributionAttrs)))
1735    return failure();
1736  if (privateAttributionAttrs)
1737    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1738                        privateAttributionAttrs);
1739
1740  // Parse the kernel attribute if present.
1741  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1742    result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
1743                        builder.getUnitAttr());
1744
1745  // Parse attributes.
1746  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1747    return failure();
1748
1749  // Parse the region. If no argument names were provided, take all names
1750  // (including those of attributions) from the entry block.
1751  auto *body = result.addRegion();
1752  return parser.parseRegion(*body, entryArgs);
1753}
1754
/// Prints a gpu.func: symbol name, signature, workgroup/private attributions,
/// optional `kernel` keyword, remaining attributes, and the body region.
1755void GPUFuncOp::print(OpAsmPrinter &p) {
1756  p << ' ';
1757  p.printSymbolName(getName());
1758
1759  FunctionType type = getFunctionType();
1760  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1761                                                  /*isVariadic=*/false,
1762                                                  type.getResults());
1763
1764  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs(),
1765                    getWorkgroupAttribAttrs().value_or(nullptr));
1766  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1767                    getPrivateAttribAttrs().value_or(nullptr));
1768  if (isKernel())
1769    p << ' ' << getKernelKeyword();
1770
// NOTE(review): the attribute-printing call opener (eliding the attributes
// listed below, which are printed structurally) is on a line elided from this
// extract.
1772      p, *this,
1773      {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
1774       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1775       getArgAttrsAttrName(), getResAttrsAttrName(),
1776       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1777  p << ' ';
1778  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1779}
1780
1781static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1782 StringAttr attrName) {
1783 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1784 if (!allAttrs || index >= allAttrs.size())
1785 return DictionaryAttr();
1786 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1787}
1788
/// Returns the attribute dictionary attached to the `index`-th workgroup
/// attribution, or null when none was set.
DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

/// Returns the attribute dictionary attached to the `index`-th private
/// attribution, or null when none was set.
DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}
1796
1797static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1798 DictionaryAttr value, StringAttr attrName) {
1799 MLIRContext *ctx = op.getContext();
1800 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1801 SmallVector<Attribute> elements;
1802 if (allAttrs)
1803 elements.append(allAttrs.begin(), allAttrs.end());
1804 while (elements.size() <= index)
1805 elements.push_back(DictionaryAttr::get(ctx));
1806 if (!value)
1807 elements[index] = DictionaryAttr::get(ctx);
1808 else
1809 elements[index] = value;
1810 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1811 op->setAttr(attrName, newValue);
1812}
1813
/// Sets the attribute dictionary of the `index`-th workgroup attribution; a
/// null `value` resets it to an empty dictionary.
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

/// Sets the attribute dictionary of the `index`-th private attribution; a
/// null `value` resets it to an empty dictionary.
void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
1823
1824static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1825 StringAttr name, StringAttr attrsName) {
1826 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1827 if (!dict)
1828 return Attribute();
1829 return dict.get(name);
1830}
1831
/// Returns the attribute `name` attached to the `index`-th workgroup
/// attribution, or null when absent.
Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

/// Returns the attribute `name` attached to the `index`-th private
/// attribution, or null when absent.
Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
1847
1848static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1849 Attribute value, StringAttr attrsName) {
1850 MLIRContext *ctx = op.getContext();
1852 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1853 if (oldDict)
1854 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1855
1856 bool found = false;
1857 bool mustSort = true;
1858 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1859 if (elems[i].getName() == name) {
1860 found = true;
1861 if (!value) {
1862 std::swap(elems[i], elems[elems.size() - 1]);
1863 elems.pop_back();
1864 } else {
1865 mustSort = false;
1866 elems[i] = NamedAttribute(elems[i].getName(), value);
1867 }
1868 break;
1869 }
1870 }
1871 if (!found) {
1872 if (!value)
1873 return;
1874 elems.emplace_back(name, value);
1875 }
1876 if (mustSort) {
1877 DictionaryAttr::sortInPlace(elems);
1878 }
1879 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1880 setAttributionAttrs(op, index, newDict, attrsName);
1881}
1882
/// Sets (or removes, when `value` is null) the attribute `name` on the
/// `index`-th workgroup attribution.
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}

/// Sets (or removes, when `value` is null) the attribute `name` on the
/// `index`-th private attribution.
void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
1898
1899LogicalResult GPUFuncOp::verifyType() {
1900 if (isKernel() && getFunctionType().getNumResults() != 0)
1901 return emitOpError() << "expected void return type for kernel function";
1902
1903 return success();
1904}
1905
1906/// Verifies the body of the function.
1907LogicalResult GPUFuncOp::verifyBody() {
1908 if (empty())
1909 return emitOpError() << "expected body with at least one block";
1910 unsigned numFuncArguments = getNumArguments();
1911 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1912 unsigned numBlockArguments = front().getNumArguments();
1913 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1914 return emitOpError() << "expected at least "
1915 << numFuncArguments + numWorkgroupAttributions
1916 << " arguments to body region";
1917
1918 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1919 for (unsigned i = 0; i < numFuncArguments; ++i) {
1920 Type blockArgType = front().getArgument(i).getType();
1921 if (funcArgTypes[i] != blockArgType)
1922 return emitOpError() << "expected body region argument #" << i
1923 << " to be of type " << funcArgTypes[i] << ", got "
1924 << blockArgType;
1925 }
1926
1927 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
1928 GPUDialect::getWorkgroupAddressSpace())) ||
1929 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1930 GPUDialect::getPrivateAddressSpace())))
1931 return failure();
1932
1933 return success();
1934}
1935
1936//===----------------------------------------------------------------------===//
1937// ReturnOp
1938//===----------------------------------------------------------------------===//
1939
1940LogicalResult gpu::ReturnOp::verify() {
1941 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1942
1943 FunctionType funType = function.getFunctionType();
1944
1945 if (funType.getNumResults() != getOperands().size())
1946 return emitOpError()
1947 .append("expected ", funType.getNumResults(), " result operands")
1948 .attachNote(function.getLoc())
1949 .append("return type declared here");
1950
1951 for (const auto &pair : llvm::enumerate(
1952 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1953 auto [type, operand] = pair.value();
1954 if (type != operand.getType())
1955 return emitOpError() << "unexpected type `" << operand.getType()
1956 << "' for operand #" << pair.index();
1957 }
1958 return success();
1959}
1960
1961//===----------------------------------------------------------------------===//
1962// GPUModuleOp
1963//===----------------------------------------------------------------------===//
1964
1965void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1966 StringRef name, ArrayAttr targets,
1967 Attribute offloadingHandler) {
1968 result.addRegion()->emplaceBlock();
1969 Properties &props = result.getOrAddProperties<Properties>();
1970 if (targets)
1971 props.targets = targets;
1972 props.setSymName(builder.getStringAttr(name));
1973 props.offloadingHandler = offloadingHandler;
1974}
1975
/// Convenience builder: wraps `targets` in an ArrayAttr (null when empty) and
/// forwards to the primary builder.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}
1983
1984bool GPUModuleOp::hasTarget(Attribute target) {
1985 if (ArrayAttr targets = getTargetsAttr())
1986 return llvm::count(targets.getValue(), target);
1987 return false;
1988}
1989
1990void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1991 ArrayAttr &targetsAttr = getProperties().targets;
1992 SmallVector<Attribute> targetsVector(targets);
1993 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1994}
1995
1996LogicalResult GPUModuleOp::verify() {
1997 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1998
1999 if (!targets)
2000 return success();
2001
2002 for (auto target : targets) {
2003 if (auto verifyTargetAttr =
2004 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
2005 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
2006 return failure();
2007 }
2008 }
2009 return success();
2010}
2011
2012//===----------------------------------------------------------------------===//
2013// GPUBinaryOp
2014//===----------------------------------------------------------------------===//
2015void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2016 Attribute offloadingHandler, ArrayAttr objects) {
2017 auto &properties = result.getOrAddProperties<Properties>();
2018 result.attributes.push_back(builder.getNamedAttr(
2020 properties.objects = objects;
2021 if (offloadingHandler)
2022 properties.offloadingHandler = offloadingHandler;
2023 else
2024 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
2025}
2026
/// Convenience builder: wraps `objects` in an ArrayAttr (null when empty) and
/// forwards to the primary builder.
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
2032
2033static ParseResult parseOffloadingHandler(OpAsmParser &parser,
2034 Attribute &offloadingHandler) {
2035 if (succeeded(parser.parseOptionalLess())) {
2036 if (parser.parseAttribute(offloadingHandler))
2037 return failure();
2038 if (parser.parseGreater())
2039 return failure();
2040 }
2041 if (!offloadingHandler)
2042 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2043 return success();
2044}
2045
2047 Attribute offloadingHandler) {
2048 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2049 printer << '<' << offloadingHandler << '>';
2050}
2051
2052//===----------------------------------------------------------------------===//
2053// GPUMemcpyOp
2054//===----------------------------------------------------------------------===//
2055
2056LogicalResult MemcpyOp::verify() {
2057 auto srcType = getSrc().getType();
2058 auto dstType = getDst().getType();
2059
2060 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2061 return emitOpError("arguments have incompatible element type");
2062
2063 if (failed(verifyCompatibleShape(srcType, dstType)))
2064 return emitOpError("arguments have incompatible shape");
2065
2066 return success();
2067}
2068
2069namespace {
2070
2071/// Erases a common case of copy ops where a destination value is used only by
2072/// the copy op, alloc and dealloc ops.
2073struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2074 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2075
2076 LogicalResult matchAndRewrite(MemcpyOp op,
2077 PatternRewriter &rewriter) const override {
2078 Value dest = op.getDst();
2079 Operation *destDefOp = dest.getDefiningOp();
2080 // `dest` must be defined by an op having Allocate memory effect in order to
2081 // perform the folding.
2082 if (!destDefOp ||
2084 return failure();
2085 // We can erase `op` iff `dest` has no other use apart from its
2086 // use by `op` and dealloc ops.
2087 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2088 return user != op &&
2089 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2090 }))
2091 return failure();
2092 // We can perform the folding if and only if op has a single async
2093 // dependency and produces an async token as result, or if it does not have
2094 // any async dependency and does not produce any async token result.
2095 if (op.getAsyncDependencies().size() > 1 ||
2096 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2097 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2098 return failure();
2099 rewriter.replaceOp(op, op.getAsyncDependencies());
2100 return success();
2101 }
2102};
2103
2104} // end anonymous namespace
2105
/// Registers canonicalization patterns: erase copies whose destination is
/// only touched by alloc/dealloc-like ops.
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
2110
2111//===----------------------------------------------------------------------===//
2112// GPU_SubgroupMmaLoadMatrixOp
2113//===----------------------------------------------------------------------===//
2114
2115LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2116 auto srcType = getSrcMemref().getType();
2117 auto resType = getRes().getType();
2118 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2119 auto operand = resMatrixType.getOperand();
2120 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2121
2122 if (!srcMemrefType.isLastDimUnitStride())
2123 return emitError(
2124 "expected source memref most minor dim must have unit stride");
2125
2126 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2127 return emitError("only AOp, BOp and COp can be loaded");
2128
2129 return success();
2130}
2131
2132//===----------------------------------------------------------------------===//
2133// GPU_SubgroupMmaStoreMatrixOp
2134//===----------------------------------------------------------------------===//
2135
2136LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2137 auto srcType = getSrc().getType();
2138 auto dstType = getDstMemref().getType();
2139 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2140 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2141
2142 if (!dstMemrefType.isLastDimUnitStride())
2143 return emitError(
2144 "expected destination memref most minor dim must have unit stride");
2145
2146 if (srcMatrixType.getOperand() != "COp")
2147 return emitError(
2148 "expected the operand matrix being stored to have 'COp' operand type");
2149
2150 return success();
2151}
2152
2153//===----------------------------------------------------------------------===//
2154// GPU_SubgroupMmaComputeOp
2155//===----------------------------------------------------------------------===//
2156
2157LogicalResult SubgroupMmaComputeOp::verify() {
2158 enum OperandMap { A, B, C };
2159 SmallVector<MMAMatrixType, 3> opTypes;
2160 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2161 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2162 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2163
2164 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2165 opTypes[C].getOperand() != "COp")
2166 return emitError("operands must be in the order AOp, BOp, COp");
2167
2168 ArrayRef<int64_t> aShape, bShape, cShape;
2169 aShape = opTypes[A].getShape();
2170 bShape = opTypes[B].getShape();
2171 cShape = opTypes[C].getShape();
2172
2173 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2174 bShape[1] != cShape[1])
2175 return emitError("operand shapes do not satisfy matmul constraints");
2176
2177 return success();
2178}
2179
/// Folds memref.cast-produced operands into the memcpy when possible.
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

/// Folds memref.cast-produced operands into the memset when possible.
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2189
2190//===----------------------------------------------------------------------===//
2191// GPU_WaitOp
2192//===----------------------------------------------------------------------===//
2193
2194namespace {
2195
2196/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2197/// %t = gpu.wait async [] // No async dependencies.
2198/// ... gpu.wait ... [%t, ...] // %t can be removed.
2199struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2200public:
2202
2203 LogicalResult matchAndRewrite(WaitOp op,
2204 PatternRewriter &rewriter) const final {
2205 auto predicate = [](Value value) {
2206 auto waitOp = value.getDefiningOp<WaitOp>();
2207 return waitOp && waitOp->getNumOperands() == 0;
2208 };
2209 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2210 return failure();
2211 SmallVector<Value> validOperands;
2212 for (Value operand : op->getOperands()) {
2213 if (predicate(operand))
2214 continue;
2215 validOperands.push_back(operand);
2216 }
2217 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2218 return success();
2219 }
2220};
2221
2222/// Simplify trivial gpu.wait ops for the following patterns.
2223/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2224/// dependencies).
2225/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2226/// %t0.
2227/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2228/// dependencies nor return any token.
2229struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2230public:
2232
2233 LogicalResult matchAndRewrite(WaitOp op,
2234 PatternRewriter &rewriter) const final {
2235 // Erase gpu.wait ops that neither have any async dependencies nor return
2236 // any async token.
2237 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2238 rewriter.eraseOp(op);
2239 return success();
2240 }
2241 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2242 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2243 op.getAsyncToken()) {
2244 rewriter.replaceOp(op, op.getAsyncDependencies());
2245 return success();
2246 }
2247 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2248 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2249 rewriter.eraseOp(op);
2250 return success();
2251 }
2252 return failure();
2253 }
2254};
2255
2256} // end anonymous namespace
2257
/// Registers gpu.wait canonicalizations: dropping redundant wait-on-empty
/// dependencies and simplifying trivial waits.
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
2262
2263//===----------------------------------------------------------------------===//
2264// GPU_AllocOp
2265//===----------------------------------------------------------------------===//
2266
2267LogicalResult AllocOp::verify() {
2268 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2269
2270 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2271 getDynamicSizes())))
2272 return failure();
2273
2274 unsigned numSymbols = 0;
2275 if (!memRefType.getLayout().isIdentity())
2276 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2277 if (getSymbolOperands().size() != numSymbols) {
2278 return emitOpError(
2279 "symbol operand count does not equal memref symbol count");
2280 }
2281
2282 return success();
2283}
2284
2285namespace {
2286
2287/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2288/// `memref::AllocOp`.
2289struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2290 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2291
2292 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2293 PatternRewriter &rewriter) const override {
2294 std::optional<int64_t> index = dimOp.getConstantIndex();
2295 if (!index)
2296 return failure();
2297
2298 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2299 if (!memrefType || index.value() >= memrefType.getRank() ||
2300 !memrefType.isDynamicDim(index.value()))
2301 return failure();
2302
2303 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2304 if (!alloc)
2305 return failure();
2306
2307 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2308 memrefType.getDynamicDimIndex(index.value()));
2309 rewriter.replaceOp(dimOp, substituteOp);
2310 return success();
2311 }
2312};
2313
2314} // namespace
2315
/// Registers gpu.alloc canonicalizations: fold memref.dim of an alloc result
/// to the corresponding dynamic-size operand.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
2320
2321//===----------------------------------------------------------------------===//
2322// GPU object attribute
2323//===----------------------------------------------------------------------===//
2324
2325LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2326 Attribute target, CompilationTarget format,
2327 StringAttr object, DictionaryAttr properties,
2328 KernelTableAttr kernels) {
2329 if (!target)
2330 return emitError() << "the target attribute cannot be null";
2331 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2332 return success();
2333 return emitError() << "the target attribute must implement or promise the "
2334 "`gpu::TargetAttrInterface`";
2335}
2336
2337namespace {
/// Parses the `[format =] object` payload of a #gpu.object attribute. The
/// format keyword is optional and defaults to `fatbin`.
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  // No keyword at all: use the default format.
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  // A recognized format keyword must be followed by `=`.
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  // A keyword was present but did not name a known format.
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}
2362
2363void printObject(AsmPrinter &odsParser, CompilationTarget format,
2364 StringAttr object) {
2365 if (format != CompilationTarget::Fatbin)
2366 odsParser << stringifyEnum(format) << " = ";
2367 odsParser << object;
2368}
2369} // namespace
2370
2371//===----------------------------------------------------------------------===//
2372// GPU select object attribute
2373//===----------------------------------------------------------------------===//
2374
2375LogicalResult
2376gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2377 Attribute target) {
2378 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2379 if (target) {
2380 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2381 if (intAttr.getInt() < 0) {
2382 return emitError() << "the object index must be positive";
2383 }
2384 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2385 return emitError()
2386 << "the target attribute must be a GPU Target attribute";
2387 }
2388 }
2389 return success();
2390}
2391
2392//===----------------------------------------------------------------------===//
2393// DynamicSharedMemoryOp
2394//===----------------------------------------------------------------------===//
2395
2396LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2397 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2398 return emitOpError() << "must be inside an op with symbol table";
2399
2400 MemRefType memrefType = getResultMemref().getType();
2401 // Check address space
2402 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2403 return emitOpError() << "address space must be "
2404 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2405 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2406 }
2407 if (memrefType.hasStaticShape()) {
2408 return emitOpError() << "result memref type must be memref<?xi8, "
2409 "#gpu.address_space<workgroup>>";
2410 }
2411 return success();
2412}
2413
2414//===----------------------------------------------------------------------===//
2415// GPU WarpExecuteOnLane0Op
2416//===----------------------------------------------------------------------===//
2417
/// Prints a warp_execute_on_lane_0 op:
///   `(` lane-id `)` `[` warp-size `]` (`args` ...)? (`->` ...)? region
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  // The warp size is printed inline, so elide it from the attr-dict below.
  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
2435
/// Parses a warp_execute_on_lane_0 op:
///   `(` lane-id `)` `[` warp-size `]` (`args` `(` operands `:` types `)`)?
///   (`->` `(` result-types `)`)? region attr-dict?
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  // Parse the `[warp-size]` integer attribute.
  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  // The lane id is always an index value.
  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  // Add an implicit gpu.yield terminator if the region elided it.
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
2492
2493void WarpExecuteOnLane0Op::getSuccessorRegions(
2494 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2495 if (!point.isParent()) {
2496 regions.push_back(RegionSuccessor::parent());
2497 return;
2498 }
2499
2500 // The warp region is always executed
2501 regions.push_back(RegionSuccessor(&getWarpRegion()));
2502}
2503
2504ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2505 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2506}
/// Builds a warp_execute_on_lane_0 op without forwarded args/block arguments.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/{}, /*argTypes=*/{});
}

/// Builds a warp_execute_on_lane_0 op forwarding `args` into the region as
/// block arguments of types `blockArgTypes`.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  // NOTE(review): assumes getAttributeNames()[0] is the warp-size attribute
  // name — confirm against the generated op definition.
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  // Create the single entry block with one argument per forwarded value.
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
2530
2531/// Helper check if the distributed vector type is consistent with the expanded
2532/// type and distributed size.
2533static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2534 int64_t warpSize, Operation *op) {
2535 // If the types matches there is no distribution.
2536 if (expanded == distributed)
2537 return success();
2538 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2539 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2540 if (!expandedVecType || !distributedVecType)
2541 return op->emitOpError("expected vector type for distributed operands.");
2542 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2543 expandedVecType.getElementType() != distributedVecType.getElementType())
2544 return op->emitOpError(
2545 "expected distributed vectors to have same rank and element type.");
2546
2547 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2548 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2549 int64_t eDim = expandedVecType.getDimSize(i);
2550 int64_t dDim = distributedVecType.getDimSize(i);
2551 if (eDim == dDim)
2552 continue;
2553 if (eDim % dDim != 0)
2554 return op->emitOpError()
2555 << "expected expanded vector dimension #" << i << " (" << eDim
2556 << ") to be a multipler of the distributed vector dimension ("
2557 << dDim << ")";
2558 scales[i] = eDim / dDim;
2559 }
2560 if (llvm::product_of(scales) != warpSize)
2561 return op->emitOpError()
2562 << "incompatible distribution dimensions from " << expandedVecType
2563 << " to " << distributedVecType << " with warp size = " << warpSize;
2564
2565 return success();
2566}
2567
/// Verifies that op args/results and region block-args/yielded values are
/// valid distributed counterparts for the configured warp size.
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
  if (!yield)
    return emitOpError("expected body to be terminated with 'gpu.yield'");
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  // Forwarded operands distribute into the region's block arguments.
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  // Yielded values distribute into the op's results.
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}
2593bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2594 return succeeded(
2595 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2596}
2597
/// Returns the gpu.yield op terminating the warp region.
gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
2601
2602//===----------------------------------------------------------------------===//
2603// GPU_SubgroupBroadcastOp
2604//===----------------------------------------------------------------------===//
2605
/// The broadcast result carries the same integer range as its source operand.
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}
2610
2611Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2612 switch (getBroadcastType()) {
2613 case BroadcastType::first_active_lane:
2614 // Cannot speculate first_lane broadcast, because speculating it across
2615 // control flow can change the active lanes.
2617 case BroadcastType::specific_lane:
2618 // Speculation should be safe as long as we inside structured control flow.
2620 }
2621 llvm_unreachable("Unknown BroadcastType");
2622}
2623
2624LogicalResult gpu::SubgroupBroadcastOp::verify() {
2625 switch (getBroadcastType()) {
2626 case BroadcastType::first_active_lane:
2627 if (getLane())
2628 return emitOpError()
2629 << "lane can only be specified for `specific_lane` broadcast";
2630 return success();
2631 case BroadcastType::specific_lane:
2632 if (!getLane())
2633 return emitOpError()
2634 << "lane must be specified for `specific_lane` broadcast";
2635 return success();
2636 }
2637 llvm_unreachable("Unknown BroadcastType");
2638}
2639
2640OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2641 // Broadcast result is always uniform.
2642 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2643 return prev.getResult();
2644
2645 return nullptr;
2646}
2647
2648//===----------------------------------------------------------------------===//
2649// GPU_BallotOp
2650//===----------------------------------------------------------------------===//
2651
2652// No custom implementations needed; ballot uses default behavior from ODS.
2653
2654//===----------------------------------------------------------------------===//
2655// GPU KernelMetadataAttr
2656//===----------------------------------------------------------------------===//
2657
2658KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2659 DictionaryAttr metadata) {
2660 assert(kernel && "invalid kernel");
2661 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2662 kernel.getAllArgAttrs(), metadata);
2663}
2664
2665KernelMetadataAttr
2666KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2667 FunctionOpInterface kernel,
2668 DictionaryAttr metadata) {
2669 assert(kernel && "invalid kernel");
2670 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2671 kernel.getAllArgAttrs(), metadata);
2672}
2673
2674KernelMetadataAttr
2675KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2676 if (attrs.empty())
2677 return *this;
2678 NamedAttrList attrList;
2679 if (DictionaryAttr dict = getMetadata())
2680 attrList.append(dict);
2681 attrList.append(attrs);
2682 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2683 attrList.getDictionary(getContext()));
2684}
2685
2686LogicalResult
2687KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2688 StringAttr name, Type functionType,
2689 ArrayAttr argAttrs, DictionaryAttr metadata) {
2690 if (name.empty())
2691 return emitError() << "the kernel name can't be empty";
2692 if (argAttrs) {
2693 if (llvm::any_of(argAttrs, [](Attribute attr) {
2694 return !llvm::isa<DictionaryAttr>(attr);
2695 }))
2696 return emitError()
2697 << "all attributes in the array must be a dictionary attribute";
2698 }
2699 return success();
2700}
2701
2702//===----------------------------------------------------------------------===//
2703// GPU KernelTableAttr
2704//===----------------------------------------------------------------------===//
2705
2706KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2707 ArrayRef<KernelMetadataAttr> kernels,
2708 bool isSorted) {
2709 // Note that `is_sorted` is always only invoked once even with assertions ON.
2710 assert((!isSorted || llvm::is_sorted(kernels)) &&
2711 "expected a sorted kernel array");
2712 // Immediately return the attribute if the array is sorted.
2713 if (isSorted || llvm::is_sorted(kernels))
2714 return Base::get(context, kernels);
2715 // Sort the array.
2716 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2717 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2718 return Base::get(context, kernelsTmp);
2719}
2720
2721KernelTableAttr KernelTableAttr::getChecked(
2722 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2723 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2724 // Note that `is_sorted` is always only invoked once even with assertions ON.
2725 assert((!isSorted || llvm::is_sorted(kernels)) &&
2726 "expected a sorted kernel array");
2727 // Immediately return the attribute if the array is sorted.
2728 if (isSorted || llvm::is_sorted(kernels))
2729 return Base::getChecked(emitError, context, kernels);
2730 // Sort the array.
2731 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2732 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2733 return Base::getChecked(emitError, context, kernelsTmp);
2734}
2735
2736LogicalResult
2737KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2738 ArrayRef<KernelMetadataAttr> kernels) {
2739 if (kernels.size() < 2)
2740 return success();
2741 // Check that the kernels are uniquely named.
2742 if (std::adjacent_find(kernels.begin(), kernels.end(),
2743 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2744 return l.getName() == r.getName();
2745 }) != kernels.end()) {
2746 return emitError() << "expected all kernels to be uniquely named";
2747 }
2748 return success();
2749}
2750
2751KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2752 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2753 return found ? *iterator : KernelMetadataAttr();
2754}
2755
2756KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2757 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2758 return found ? *iterator : KernelMetadataAttr();
2759}
2760
2761//===----------------------------------------------------------------------===//
2762// GPU target options
2763//===----------------------------------------------------------------------===//
2764
2779
2797
/// Returns the typeID.
TypeID TargetOptions::getTypeID() const { return typeID; }
2799
/// Returns the toolkit path.
StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2801
2805
/// Returns the command line options.
StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2807
/// Returns the ELF section where the binary needs to be located.
StringRef TargetOptions::getELFSection() const { return elfSection; }
2809
2813
2814function_ref<void(llvm::Module &)>
2818
2819function_ref<void(llvm::Module &)>
2823
2824function_ref<void(llvm::Module &)>
2828
2830 return isaCallback;
2831}
2832
/// Returns the compilation process target format.
CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}
2836
2838 return CompilationTarget::Fatbin;
2839}
2840
2841std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2843 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2844 llvm::StringSaver stringSaver(options.first);
2845 StringRef opts = cmdOptions;
2846 // For a correct tokenization of the command line options `opts` must be
2847 // unquoted, otherwise the tokenization function returns a single string: the
2848 // unquoted `cmdOptions` -which is not the desired behavior.
2849 // Remove any quotes if they are at the beginning and end of the string:
2850 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2851 opts.consume_front("\""), opts.consume_back("\"");
2852 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2853 opts.consume_front("'"), opts.consume_back("'");
2854#ifdef _WIN32
2855 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2856 /*MarkEOLs=*/false);
2857#else
2858 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2859 /*MarkEOLs=*/false);
2860#endif // _WIN32
2861 return options;
2862}
2863
2864std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2868
2869std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2871 size_t startPos = cmdOptions.find(startsWith);
2872 if (startPos == std::string::npos)
2873 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2874
2875 auto tokenized =
2876 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2877 cmdOptions.resize(startPos);
2878 return tokenized;
2879}
2880
2882
2883#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2884#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2885
2886#define GET_ATTRDEF_CLASSES
2887#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2888
2889#define GET_OP_CLASSES
2890#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2891
2892#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator points to the found element or to the lower bound; this helper avoids that.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shaped type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on invalid construction.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39