MLIR 23.0.0git
GPUDialect.cpp
Go to the documentation of this file.
1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
30#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
43#include <cassert>
44#include <numeric>
45#include <optional>
46
47using namespace mlir;
48using namespace mlir::gpu;
49
50#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
51
52//===----------------------------------------------------------------------===//
53// GPU Device Mapping Attributes
54//===----------------------------------------------------------------------===//
55
56int64_t GPUBlockMappingAttr::getMappingId() const {
57 return static_cast<int64_t>(getBlock());
58}
59
60bool GPUBlockMappingAttr::isLinearMapping() const {
61 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
62}
63
64int64_t GPUBlockMappingAttr::getRelativeIndex() const {
65 return isLinearMapping()
66 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
67 : getMappingId();
68}
69
70int64_t GPUWarpgroupMappingAttr::getMappingId() const {
71 return static_cast<int64_t>(getWarpgroup());
72}
73
74bool GPUWarpgroupMappingAttr::isLinearMapping() const {
75 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
76}
77
78int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
79 return isLinearMapping()
80 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
81 : getMappingId();
82}
83
84int64_t GPUWarpMappingAttr::getMappingId() const {
85 return static_cast<int64_t>(getWarp());
86}
87
88bool GPUWarpMappingAttr::isLinearMapping() const {
89 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
90}
91
92int64_t GPUWarpMappingAttr::getRelativeIndex() const {
93 return isLinearMapping()
94 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
95 : getMappingId();
96}
97
98int64_t GPUThreadMappingAttr::getMappingId() const {
99 return static_cast<int64_t>(getThread());
100}
101
102bool GPUThreadMappingAttr::isLinearMapping() const {
103 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
104}
105
106int64_t GPUThreadMappingAttr::getRelativeIndex() const {
107 return isLinearMapping()
108 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
109 : getMappingId();
110}
111
112int64_t GPULaneMappingAttr::getMappingId() const {
113 return static_cast<int64_t>(getLane());
114}
115
116bool GPULaneMappingAttr::isLinearMapping() const {
117 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
118}
119
120int64_t GPULaneMappingAttr::getRelativeIndex() const {
121 return isLinearMapping()
122 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
123 : getMappingId();
124}
125
126int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
127
128/// 8 4 0
129/// Example mask : 0 0 0 1 1 0 1 0 0
130///
131/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
132/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
133///
134/// Example mask : 0 0 0 1 1 0 1 0 0
135/// Example filter: 0 0 0 0 1 1 1 1 1
136/// Intersection : 0 0 0 0 1 0 1 0 0
137/// PopCnt : 2
138Value GPUMappingMaskAttr::createLogicalLinearMappingId(
139 OpBuilder &b, Value physicalLinearMappingId) const {
140 Location loc = physicalLinearMappingId.getLoc();
141 Value mask =
142 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
143 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
144 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
145 filter = arith::SubIOp::create(b, loc, filter, one);
146 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
147 return math::CtPopOp::create(b, loc, filteredId);
148}
149
150/// 8 4 0
151/// Example mask : 0 0 0 1 1 0 1 0 0
152///
153/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
154/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
155///
156/// Example mask : 0 0 0 1 1 0 1 0 0
157/// Example filter: 0 0 0 1 0 0 0 0 0
158/// Intersection : 0 0 0 1 0 0 0 0 0
159/// Cmp : 1
160Value GPUMappingMaskAttr::createIsActiveIdPredicate(
161 OpBuilder &b, Value physicalLinearMappingId) const {
162 Location loc = physicalLinearMappingId.getLoc();
163 Value mask =
164 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
165 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
166 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
167 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
168 Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
169 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
170 zero);
171}
172
173int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
174 return static_cast<int64_t>(getAddressSpace());
175}
176
/// Not supported for memory-space mappings: the body is a single
/// llvm_unreachable, so this function never returns a value.
177bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
178 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
179}
180
/// Not supported for memory-space mappings: the body is a single
/// llvm_unreachable, so this function never returns a value.
181int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
182 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
183}
184
185//===----------------------------------------------------------------------===//
186// MMAMatrixType
187//===----------------------------------------------------------------------===//
188
190 StringRef operand) {
191 return Base::get(elementType.getContext(), shape, elementType, operand);
192}
193
196 ArrayRef<int64_t> shape, Type elementType,
197 StringRef operand) {
198 return Base::getChecked(emitError, elementType.getContext(), shape,
199 elementType, operand);
200}
201
/// Returns the `numDims` field of the object returned by getImpl().
202unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
203
205 return getImpl()->getShape();
206}
207
/// Returns the `elementType` field of the object returned by getImpl().
208Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
209
/// Returns the operand string by forwarding to getImpl()->getOperand().
210StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
211
213 return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
214 elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
215 elementType.isInteger(32);
216}
217
218LogicalResult
220 ArrayRef<int64_t> shape, Type elementType,
221 StringRef operand) {
222 if (operand != "AOp" && operand != "BOp" && operand != "COp")
223 return emitError() << "operand expected to be one of AOp, BOp or COp";
224
225 if (shape.size() != 2)
226 return emitError() << "MMAMatrixType must have exactly two dimensions";
227
228 if (!MMAMatrixType::isValidElementType(elementType))
229 return emitError()
230 << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
231
232 return success();
233}
234
235//===----------------------------------------------------------------------===//
236// GPUDialect
237//===----------------------------------------------------------------------===//
238
239bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
240 if (!memorySpace)
241 return false;
242 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
243 return gpuAttr.getValue() == getWorkgroupAddressSpace();
244 return false;
245}
246
247bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
248 Attribute memorySpace = type.getMemorySpace();
249 return isWorkgroupMemoryAddressSpace(memorySpace);
250}
251
252bool GPUDialect::isConstantMemoryAddressSpace(Attribute memorySpace) {
253 if (!memorySpace)
254 return false;
255 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
256 return gpuAttr.getValue() == getConstantAddressSpace();
257 return false;
258}
259
260bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
261 Attribute memorySpace = type.getMemorySpace();
262 return isConstantMemoryAddressSpace(memorySpace);
263}
264
265bool GPUDialect::isKernel(Operation *op) {
266 if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
267 return gpuFunc.isKernel();
268 return static_cast<bool>(
269 op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
270}
271
272namespace {
273/// This class defines the interface for handling inlining with gpu
274/// operations.
275struct GPUInlinerInterface : public DialectInlinerInterface {
// Inherit the base-class constructors unchanged.
276 using DialectInlinerInterface::DialectInlinerInterface;
277
278 /// All gpu dialect ops can be inlined.
// Returns true unconditionally; none of the four arguments is inspected.
279 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
280 return true;
281 }
282};
283} // namespace
284
/// Registers the dialect's types, operations, attributes and interfaces, and
/// declares the interfaces that are promised for specific ops.
285void GPUDialect::initialize() {
// Types registered one by one.
286 addTypes<AsyncTokenType>();
287 addTypes<MMAMatrixType>();
288 addTypes<NamedBarrierType>();
289 addTypes<SparseDnTensorHandleType>();
290 addTypes<SparseSpMatHandleType>();
291 addTypes<SparseSpGEMMOpHandleType>();
// The operation list comes from the included .inc file, selected by
// GET_OP_LIST, and is spliced in as the template arguments.
292 addOperations<
293#define GET_OP_LIST
294#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
295 >();
// Same pattern for attributes, selected by GET_ATTRDEF_LIST.
296 addAttributes<
297#define GET_ATTRDEF_LIST
298#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
299 >();
300 addInterfaces<GPUInlinerInterface>();
// Promised interfaces are only declared here; their implementations are not
// part of this function.
301 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
302 TerminatorOp>();
303 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
304 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
305 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
306 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
307 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
308 declarePromisedInterfaces<memref::IndexedAccessOpInterface,
309 SubgroupMmaLoadMatrixOp,
310 SubgroupMmaStoreMatrixOp>();
311}
312
313static std::string getSparseHandleKeyword(SparseHandleKind kind) {
314 switch (kind) {
316 return "sparse.dntensor_handle";
318 return "sparse.spmat_handle";
320 return "sparse.spgemmop_handle";
321 }
322 llvm_unreachable("unknown sparse handle kind");
323 return "";
324}
325
326Type GPUDialect::parseType(DialectAsmParser &parser) const {
327 // Parse the main keyword for the type.
328 StringRef keyword;
329 if (parser.parseKeyword(&keyword))
330 return Type();
331 MLIRContext *context = getContext();
332
333 // Handle 'async token' types.
334 if (keyword == "async.token")
335 return AsyncTokenType::get(context);
336
337 if (keyword == "mma_matrix") {
338 SMLoc beginLoc = parser.getNameLoc();
339
340 // Parse '<'.
341 if (parser.parseLess())
342 return nullptr;
343
344 // Parse the size and elementType.
345 SmallVector<int64_t> shape;
346 Type elementType;
347 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
348 parser.parseType(elementType))
349 return nullptr;
350
351 // Parse ','
352 if (parser.parseComma())
353 return nullptr;
354
355 // Parse operand.
356 std::string operand;
357 if (failed(parser.parseOptionalString(&operand)))
358 return nullptr;
359
360 // Parse '>'.
361 if (parser.parseGreater())
362 return nullptr;
363
365 parser.getEncodedSourceLoc(beginLoc)),
366 shape, elementType, operand);
367 }
368
369 if (keyword == "named_barrier")
370 return NamedBarrierType::get(context);
371
373 return SparseDnTensorHandleType::get(context);
375 return SparseSpMatHandleType::get(context);
377 return SparseSpGEMMOpHandleType::get(context);
378
379 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
380 return Type();
381}
382// TODO: Print the refined type here. Note that the printed form must
383// correspond to what the parser accepts.
384void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
385 TypeSwitch<Type>(type)
386 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
387 .Case<NamedBarrierType>([&](Type) { os << "named_barrier"; })
388 .Case<SparseDnTensorHandleType>([&](Type) {
390 })
391 .Case<SparseSpMatHandleType>(
393 .Case<SparseSpGEMMOpHandleType>([&](Type) {
395 })
396 .Case([&](MMAMatrixType fragTy) {
397 os << "mma_matrix<";
398 auto shape = fragTy.getShape();
399 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
400 os << *dim << 'x';
401 os << shape.back() << 'x' << fragTy.getElementType();
402 os << ", \"" << fragTy.getOperand() << "\"" << '>';
403 })
404 .DefaultUnreachable("unexpected 'gpu' type kind");
405}
406
407static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
408 NamedAttribute attr) {
409 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
410 if (!array)
411 return op->emitOpError(Twine(attr.getName()) +
412 " must be a dense i32 array");
413 if (array.size() != 3)
414 return op->emitOpError(Twine(attr.getName()) +
415 " must contain exactly 3 elements");
416 return success();
417}
418
419LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
420 NamedAttribute attr) {
421 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
422 return verifyKnownLaunchSizeAttr(op, attr);
423 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
424 return verifyKnownLaunchSizeAttr(op, attr);
425 if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
426 return verifyKnownLaunchSizeAttr(op, attr);
427 if (!llvm::isa<UnitAttr>(attr.getValue()) ||
428 attr.getName() != getContainerModuleAttrName())
429 return success();
430
431 auto module = dyn_cast<ModuleOp>(op);
432 if (!module)
433 return op->emitError("expected '")
434 << getContainerModuleAttrName() << "' attribute to be attached to '"
435 << ModuleOp::getOperationName() << '\'';
436 return success();
437}
438
439/// Parses an optional list of async operands with an optional leading keyword.
440/// (`async`)? (`[` ssa-id-list `]`)?
441///
442/// This method is used by the tablegen assembly format for async ops as well.
443static ParseResult parseAsyncDependencies(
444 OpAsmParser &parser, Type &asyncTokenType,
446 auto loc = parser.getCurrentLocation();
447 if (succeeded(parser.parseOptionalKeyword("async"))) {
448 if (parser.getNumResults() == 0)
449 return parser.emitError(loc, "needs to be named when marked 'async'");
450 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
451 }
452 return parser.parseOperandList(asyncDependencies,
454}
455
456/// Prints optional async dependencies with its leading keyword.
457/// (`async`)? (`[` ssa-id-list `]`)?
458// Used by the tablegen assembly format for several async ops.
460 Type asyncTokenType,
461 OperandRange asyncDependencies) {
462 if (asyncTokenType)
463 printer << "async";
464 if (asyncDependencies.empty())
465 return;
466 if (asyncTokenType)
467 printer << ' ';
468 printer << llvm::interleaved_array(asyncDependencies);
469}
470
471// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
472/// Parses a GPU function memory attribution.
473///
474/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
475/// (`private` `(` ssa-id-and-type-list `)`)?
476///
477/// Note that this function parses only one of the two similar parts, with the
478/// keyword provided as argument.
479static ParseResult
480parseAttributions(OpAsmParser &parser, StringRef keyword,
482 // If we could not parse the keyword, just assume empty list and succeed.
483 if (failed(parser.parseOptionalKeyword(keyword)))
484 return success();
485
487 /*allowType=*/true);
488}
489
490static void printAttributions(OpAsmPrinter &p, StringRef keyword,
492 ArrayAttr attributes = {}) {
493 if (values.empty())
494 return;
495
496 p << ' ' << keyword << '(';
497 llvm::interleaveComma(
498 llvm::enumerate(values), p, [&p, attributes](auto pair) {
499 BlockArgument v = pair.value();
500 p << v << " : " << v.getType();
501
502 size_t attributionIndex = pair.index();
503 DictionaryAttr attrs;
504 if (attributes && attributionIndex < attributes.size())
505 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
506 if (attrs)
507 p.printOptionalAttrDict(attrs.getValue());
508 });
509 p << ')';
510}
511
512/// Verifies a GPU function memory attribution.
513static LogicalResult verifyAttributions(Operation *op,
514 ArrayRef<BlockArgument> attributions,
515 gpu::AddressSpace memorySpace) {
516 for (Value v : attributions) {
517 auto type = llvm::dyn_cast<MemRefType>(v.getType());
518 if (!type)
519 return op->emitOpError() << "expected memref type in attribution";
520
521 // We can only verify the address space if it hasn't already been lowered
522 // from the AddressSpaceAttr to a target-specific numeric value.
523 auto addressSpace =
524 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
525 if (!addressSpace)
526 continue;
527 if (addressSpace.getValue() != memorySpace)
528 return op->emitOpError()
529 << "expected memory space " << stringifyAddressSpace(memorySpace)
530 << " in attribution";
531 }
532 return success();
533}
534
535//===----------------------------------------------------------------------===//
536// AllReduceOp
537//===----------------------------------------------------------------------===//
538
539static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
540 Type resType) {
541 using Kind = gpu::AllReduceOperation;
542 if (llvm::is_contained(
543 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
544 opName)) {
545 if (!isa<FloatType>(resType))
546 return failure();
547 }
548
549 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
550 Kind::AND, Kind::OR, Kind::XOR},
551 opName)) {
552 if (!isa<IntegerType>(resType))
553 return failure();
554 }
555
556 return success();
557}
558
559LogicalResult gpu::AllReduceOp::verifyRegions() {
560 if (getBody().empty() != getOp().has_value())
561 return emitError("expected either an op attribute or a non-empty body");
562 if (!getBody().empty()) {
563 if (getBody().getNumArguments() != 2)
564 return emitError("expected two region arguments");
565 for (auto argument : getBody().getArguments()) {
566 if (argument.getType() != getType())
567 return emitError("incorrect region argument type");
568 }
569 unsigned yieldCount = 0;
570 for (Block &block : getBody()) {
571 if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
572 if (yield.getNumOperands() != 1)
573 return emitError("expected one gpu.yield operand");
574 if (yield.getOperand(0).getType() != getType())
575 return emitError("incorrect gpu.yield type");
576 ++yieldCount;
577 }
578 }
579 if (yieldCount == 0)
580 return emitError("expected gpu.yield op in region");
581 } else {
582 gpu::AllReduceOperation opName = *getOp();
583 if (failed(verifyReduceOpAndType(opName, getType()))) {
584 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
585 << "` reduction operation is not compatible with type "
586 << getType();
587 }
588 }
589
590 return success();
591}
592
594 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
595 if (!launchOp)
596 return false;
597
598 Region &body = launchOp.getBody();
599 assert(!body.empty() && "Invalid region");
600
601 // Only convert ops in gpu::launch entry block for now.
602 return op->getBlock() == &body.front();
603}
604
605OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
606 if (!getUniform() && canMakeGroupOpUniform(*this)) {
607 setUniform(true);
608 return getResult();
609 }
610
611 return nullptr;
612}
613
614// TODO: Support optional custom attributes (without dialect prefix).
615static ParseResult parseAllReduceOperation(AsmParser &parser,
616 AllReduceOperationAttr &attr) {
617 StringRef enumStr;
618 if (!parser.parseOptionalKeyword(&enumStr)) {
619 std::optional<AllReduceOperation> op =
620 gpu::symbolizeAllReduceOperation(enumStr);
621 if (!op)
622 return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
623 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
624 }
625 return success();
626}
627
629 AllReduceOperationAttr attr) {
630 if (attr)
631 attr.print(printer);
632}
633
634//===----------------------------------------------------------------------===//
635// SubgroupReduceOp
636//===----------------------------------------------------------------------===//
637
638LogicalResult gpu::SubgroupReduceOp::verify() {
639 Type elemType = getType();
640 if (auto vecTy = dyn_cast<VectorType>(elemType)) {
641 if (vecTy.isScalable())
642 return emitOpError() << "is not compatible with scalable vector types";
643
644 elemType = vecTy.getElementType();
645 }
646
647 gpu::AllReduceOperation opName = getOp();
648 if (failed(verifyReduceOpAndType(opName, elemType))) {
649 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
650 << "` reduction operation is not compatible with type "
651 << getType();
652 }
653
654 auto clusterSize = getClusterSize();
655 if (clusterSize) {
656 uint32_t size = *clusterSize;
657 if (!llvm::isPowerOf2_32(size)) {
658 return emitOpError() << "cluster size " << size
659 << " is not a power of two";
660 }
661 }
662
663 uint32_t stride = getClusterStride();
664 if (stride != 1 && !clusterSize) {
665 return emitOpError() << "cluster stride can only be specified if cluster "
666 "size is specified";
667 }
668 if (!llvm::isPowerOf2_32(stride)) {
669 return emitOpError() << "cluster stride " << stride
670 << " is not a power of two";
671 }
672
673 return success();
674}
675
676OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
677 if (getClusterSize() == 1)
678 return getValue();
679
680 if (!getUniform() && canMakeGroupOpUniform(*this)) {
681 setUniform(true);
682 return getResult();
683 }
684
685 return nullptr;
686}
687
688//===----------------------------------------------------------------------===//
689// AsyncOpInterface
690//===----------------------------------------------------------------------===//
691
693 op->insertOperands(0, {token});
694 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
695 return;
696 auto attrName =
698 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
699
700 // Async dependencies is the only variadic operand.
701 if (!sizeAttr)
702 return;
703
704 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
705 ++sizes.front();
706 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
707}
708
709//===----------------------------------------------------------------------===//
710// LaunchOp
711//===----------------------------------------------------------------------===//
712
/// Populates `result` for a gpu.launch operation.
///
/// Operands are appended in this order: async dependencies, the three grid
/// sizes, the three block sizes, then each cluster size that is non-null, then
/// the dynamic shared memory size if non-null. The `module` and `function`
/// symbol attributes are added only when non-null. A body region is created
/// with one entry block holding kNumConfigRegionAttributes `index` arguments
/// followed by the workgroup and private attribution arguments. Finally the
/// operand segment size attribute is set to match the operands added.
713void LaunchOp::build(OpBuilder &builder, OperationState &result,
714 Value gridSizeX, Value gridSizeY, Value gridSizeZ,
715 Value getBlockSizeX, Value getBlockSizeY,
716 Value getBlockSizeZ, Value dynamicSharedMemorySize,
717 Type asyncTokenType, ValueRange asyncDependencies,
718 TypeRange workgroupAttributions,
719 TypeRange privateAttributions, Value clusterSizeX,
720 Value clusterSizeY, Value clusterSizeZ,
721 FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
// createBlock below moves the builder's insertion point; the guard restores
// it when this function returns.
722 OpBuilder::InsertionGuard g(builder);
723
// Only the number of workgroup attributions is recorded as an attribute.
724 if (!workgroupAttributions.empty())
725 result.addAttribute(
726 getWorkgroupAttributionsAttrName(result.name),
727 builder.getI64IntegerAttr(workgroupAttributions.size()));
728
729 // Add Op operands.
730 result.addOperands(asyncDependencies);
// `asyncTokenType` is only tested for being non-null; the result type is
// always a freshly built AsyncTokenType.
731 if (asyncTokenType)
732 result.types.push_back(builder.getType<AsyncTokenType>());
733
734 // Add grid and block sizes as op operands, followed by the data operands.
735 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
736 getBlockSizeY, getBlockSizeZ});
737 if (clusterSizeX)
738 result.addOperands(clusterSizeX);
739 if (clusterSizeY)
740 result.addOperands(clusterSizeY);
741 if (clusterSizeZ)
742 result.addOperands(clusterSizeZ);
743 if (dynamicSharedMemorySize)
744 result.addOperands(dynamicSharedMemorySize);
745
746 // Add optional module and function attributes.
747 if (module)
748 result.addAttribute(getModuleAttrName(result.name), module);
749 if (function)
750 result.addAttribute(getFunctionAttrName(result.name), function);
751
752 // Create a kernel body region with kNumConfigRegionAttributes + N memory
753 // attributions, where the first kNumConfigRegionAttributes arguments have
754 // `index` type and the rest have the same types as the data operands.
755 Region *kernelRegion = result.addRegion();
756 Block *body = builder.createBlock(kernelRegion);
757 // TODO: Allow passing in proper locations here.
758 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
759 body->addArgument(builder.getIndexType(), result.location);
760 // Add WorkGroup & Private attributions to the region arguments.
761 for (Type argTy : workgroupAttributions)
762 body->addArgument(argTy, result.location);
763 for (Type argTy : privateAttributions)
764 body->addArgument(argTy, result.location);
765 // Fill OperandSegmentSize Attribute.
// Eleven segments, each defaulting to one operand. Index 0 is overwritten
// with the async dependency count, indices 7-9 with the presence of cluster
// size X/Y/Z, and the last index with the presence of the dynamic shared
// memory size.
766 SmallVector<int32_t, 11> segmentSizes(11, 1);
767 segmentSizes.front() = asyncDependencies.size();
768 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
769 segmentSizes[7] = clusterSizeX ? 1 : 0;
770 segmentSizes[8] = clusterSizeY ? 1 : 0;
771 segmentSizes[9] = clusterSizeZ ? 1 : 0;
772 result.addAttribute(getOperandSegmentSizeAttr(),
773 builder.getDenseI32ArrayAttr(segmentSizes));
774}
775
776KernelDim3 LaunchOp::getBlockIds() {
777 assert(!getBody().empty() && "LaunchOp body must not be empty.");
778 auto args = getBody().getArguments();
779 return KernelDim3{args[0], args[1], args[2]};
780}
781
782KernelDim3 LaunchOp::getThreadIds() {
783 assert(!getBody().empty() && "LaunchOp body must not be empty.");
784 auto args = getBody().getArguments();
785 return KernelDim3{args[3], args[4], args[5]};
786}
787
788KernelDim3 LaunchOp::getGridSize() {
789 assert(!getBody().empty() && "LaunchOp body must not be empty.");
790 auto args = getBody().getArguments();
791 return KernelDim3{args[6], args[7], args[8]};
792}
793
794KernelDim3 LaunchOp::getBlockSize() {
795 assert(!getBody().empty() && "LaunchOp body must not be empty.");
796 auto args = getBody().getArguments();
797 return KernelDim3{args[9], args[10], args[11]};
798}
799
800std::optional<KernelDim3> LaunchOp::getClusterIds() {
801 assert(!getBody().empty() && "LaunchOp body must not be empty.");
802 if (!hasClusterSize())
803 return std::nullopt;
804 auto args = getBody().getArguments();
805 return KernelDim3{args[12], args[13], args[14]};
806}
807
808std::optional<KernelDim3> LaunchOp::getClusterSize() {
809 assert(!getBody().empty() && "LaunchOp body must not be empty.");
810 if (!hasClusterSize())
811 return std::nullopt;
812 auto args = getBody().getArguments();
813 return KernelDim3{args[15], args[16], args[17]};
814}
815
816KernelDim3 LaunchOp::getGridSizeOperandValues() {
817 auto operands = getOperands().drop_front(getAsyncDependencies().size());
818 return KernelDim3{operands[0], operands[1], operands[2]};
819}
820
821KernelDim3 LaunchOp::getBlockSizeOperandValues() {
822 auto operands = getOperands().drop_front(getAsyncDependencies().size());
823 return KernelDim3{operands[3], operands[4], operands[5]};
824}
825
826std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
827 auto operands = getOperands().drop_front(getAsyncDependencies().size());
828 if (!hasClusterSize())
829 return std::nullopt;
830 return KernelDim3{operands[6], operands[7], operands[8]};
831}
832
833LogicalResult LaunchOp::verify() {
834 if (!(hasClusterSize()) &&
835 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
836 return emitOpError() << "cluster size must be all present";
837 return success();
838}
839
840LogicalResult LaunchOp::verifyRegions() {
841 // Kernel launch takes kNumConfigOperands leading operands for grid/block
842 // sizes and transforms them into kNumConfigRegionAttributes region arguments
843 // for block/thread identifiers and grid/block sizes.
844 if (getBody().empty()) {
845 return emitOpError("body region is empty");
846 }
847 unsigned actualNumRegionArgs = getBody().getNumArguments();
848 unsigned expectedNumRegionArgs =
849 getNumConfigRegionAttributes() + getNumWorkgroupAttributions();
850 if (actualNumRegionArgs < expectedNumRegionArgs) {
851 return emitOpError("expected at least ")
852 << expectedNumRegionArgs << " region arguments, but got "
853 << actualNumRegionArgs;
854 }
855
856 // Verify Attributions Address Spaces.
857 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
858 GPUDialect::getWorkgroupAddressSpace())) ||
859 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
860 GPUDialect::getPrivateAddressSpace())))
861 return failure();
862
863 // Block terminators without successors are expected to exit the kernel region
864 // and must be `gpu.terminator`.
865 for (Block &block : getBody()) {
866 if (block.empty())
867 continue;
868 if (block.back().getNumSuccessors() != 0)
869 continue;
870 if (!isa<gpu::TerminatorOp>(&block.back())) {
871 return block.back()
872 .emitError()
873 .append("expected '", gpu::TerminatorOp::getOperationName(),
874 "' or a terminator with successors")
875 .attachNote(getLoc())
876 .append("in '", LaunchOp::getOperationName(), "' body region");
877 }
878 }
879
880 if (getNumResults() == 0 && getAsyncToken())
881 return emitOpError("needs to be named when async keyword is specified");
882
883 return success();
884}
885
886// Pretty-print the kernel grid/block size assignment as
887// (%iter-x, %iter-y, %iter-z) in
888// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
889// where %size-* and %iter-* will correspond to the body region arguments.
891 KernelDim3 operands, KernelDim3 ids) {
892 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
893 p << size.x << " = " << operands.x << ", ";
894 p << size.y << " = " << operands.y << ", ";
895 p << size.z << " = " << operands.z << ')';
896}
897
/// Prints the custom assembly form of gpu.launch, in this order: optional
/// `async` with its dependency list, the optional clusters assignment, the
/// blocks and threads assignments, the optional dynamic shared memory size,
/// the optional module and function symbols, the optional `cooperative`
/// keyword, the workgroup and private attributions, the body region, and the
/// remaining attribute dictionary.
898void LaunchOp::print(OpAsmPrinter &p) {
// The dependency list is printed only inside the async-token branch.
899 if (getAsyncToken()) {
900 p << " async";
901 if (!getAsyncDependencies().empty())
902 p << " [" << getAsyncDependencies() << ']';
903 }
904 // Print the launch configuration.
// The .value() calls are guarded by hasClusterSize(), the same condition the
// three optional-returning getters test.
905 if (hasClusterSize()) {
906 p << ' ' << getClustersKeyword();
907 printSizeAssignment(p, getClusterSize().value(),
908 getClusterSizeOperandValues().value(),
909 getClusterIds().value());
910 }
911 p << ' ' << getBlocksKeyword();
912 printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
913 getBlockIds());
914 p << ' ' << getThreadsKeyword();
915 printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
916 getThreadIds());
917 if (getDynamicSharedMemorySize())
918 p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
919 << getDynamicSharedMemorySize();
920
921 // Print optional module attribute.
// The attribute name is used both as the printed keyword and, below, as an
// entry in the elided-attribute list.
922 StringRef moduleAttrName = getModuleAttrName();
923 if (auto module = getModule()) {
924 p << ' ' << moduleAttrName << '(';
925 p.printSymbolName(*module);
926 p << ')';
927 }
928 // Print optional function attribute.
929 StringRef functionAttrName = getFunctionAttrName();
930 if (auto function = getFunction()) {
931 p << ' ' << functionAttrName << '(';
932 p.printSymbolName(*function);
933 p << ')';
934 }
935
936 if (getCooperative())
937 p << " cooperative";
938
939 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs());
940 printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
941
942 p << ' ';
943
// Entry block arguments are not printed with the region; they already appear
// in the size assignments and attribution lists printed above.
944 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
// Attributes already rendered by the custom syntax above are elided from the
// trailing dictionary.
945 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
946 LaunchOp::getOperandSegmentSizeAttr(),
947 getWorkgroupAttributionsAttrName(),
948 getCooperativeAttrName(), moduleAttrName,
949 functionAttrName});
950}
951
952// Parse the size assignment blocks for blocks and threads. These have the form
953// (%region_arg, %region_arg, %region_arg) in
954// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
955// where %region_arg are percent-identifiers for the region arguments to be
956// introduced further (SSA defs), and %operand are percent-identifiers for the
957// SSA value uses.
958static ParseResult
963 StringRef keyword) {
964 assert(indices.size() == 3 && "space for three indices expected");
967 /*allowResultNumber=*/false) ||
968 parser.parseKeyword("in") || parser.parseLParen())
969 return failure();
970
971 if (args.size() != 3) {
972 return parser.emitError(parser.getNameLoc())
973 << keyword << " expects 3 arguments, but got " << args.size();
974 }
975 std::move(args.begin(), args.end(), indices.begin());
976
977 for (int i = 0; i < 3; ++i) {
978 if (i != 0 && parser.parseComma())
979 return failure();
980 if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
981 parser.parseEqual() || parser.parseOperand(sizes[i]))
982 return failure();
983 }
984
985 return parser.parseRParen();
986}
987
/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       (`dynamic_shared_memory_size` ssa-use)?
///       (`module(` symbol-ref-id `)`)?
///       (`function(` symbol-ref-id `)`)?
///       (`cooperative`)?
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
999ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
1000 // Sizes of the grid and block.
1001 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
1002 sizes(LaunchOp::kNumConfigOperands);
1003
1004 // Region arguments to be created.
1005 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
1006 LaunchOp::kNumConfigRegionAttributes);
1007
1008 // Parse optional async dependencies.
1009 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
1010 Type asyncTokenType;
1011 if (failed(
1012 parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
1013 parser.resolveOperands(asyncDependencies, asyncTokenType,
1014 result.operands))
1015 return failure();
1016 if (parser.getNumResults() > 0) {
1017 if (!asyncTokenType)
1018 return parser.emitError(
1019 parser.getNameLoc(),
1020 "gpu.launch requires 'async' keyword to return a value");
1021 result.types.push_back(asyncTokenType);
1022 }
1023
1024 bool hasCluster = false;
1025 if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
1026 hasCluster = true;
1027 sizes.resize(9);
1028 regionArgs.resize(18);
1029 }
1030 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1031 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1032
1033 // Last three segment assigns the cluster size. In the region argument
1034 // list, this is last 6 arguments.
1035 if (hasCluster) {
1037 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1038 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1039 return failure();
1040 }
1041 // Parse the size assignment segments: the first segment assigns grid sizes
1042 // and defines values for block identifiers; the second segment assigns block
1043 // sizes and defines values for thread identifiers. In the region argument
1044 // list, identifiers precede sizes, and block-related values precede
1045 // thread-related values.
1046 if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
1047 parseSizeAssignment(parser, sizesRef.take_front(3),
1048 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1049 LaunchOp::getBlocksKeyword()) ||
1050 parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
1051 parseSizeAssignment(parser, sizesRef.drop_front(3),
1052 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1053 LaunchOp::getThreadsKeyword()) ||
1054 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
1055 result.operands))
1056 return failure();
1057
1058 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1059 bool hasDynamicSharedMemorySize = false;
1060 if (!parser.parseOptionalKeyword(
1061 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1062 hasDynamicSharedMemorySize = true;
1063 if (parser.parseOperand(dynamicSharedMemorySize) ||
1064 parser.resolveOperand(dynamicSharedMemorySize,
1065 parser.getBuilder().getI32Type(),
1066 result.operands))
1067 return failure();
1068 }
1069
1070 // Parse optional module attribute.
1071 StringRef moduleAttrName = getModuleAttrName(result.name);
1072 if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1073 FlatSymbolRefAttr moduleSymbol;
1074 if (parser.parseLParen() ||
1075 parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1076 result.attributes) ||
1077 parser.parseRParen())
1078 return failure();
1079 }
1080 // Parse optional function attribute.
1081 StringRef functionAttrName = getFunctionAttrName(result.name);
1082 if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1083 FlatSymbolRefAttr funcSymbol;
1084 if (parser.parseLParen() ||
1085 parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1086 result.attributes) ||
1087 parser.parseRParen())
1088 return failure();
1089 }
1090
1091 // Parse optional cooperative keyword.
1092 if (succeeded(parser.parseOptionalKeyword("cooperative")))
1093 result.addAttribute("cooperative", parser.getBuilder().getUnitAttr());
1094
1095 // Create the region arguments: fixed launch-config args (`index`), then
1096 // workgroup / private attribution args. The workgroup count is stored in the
1097 // inherent `workgroup_attributions` attribute when non-zero.
1098 Type index = parser.getBuilder().getIndexType();
1099 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1100 LaunchOp::kNumConfigRegionAttributes + 6, index);
1101
1102 SmallVector<OpAsmParser::Argument> regionArguments;
1103 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1104 OpAsmParser::Argument arg;
1105 arg.ssaName = std::get<0>(ssaValueAndType);
1106 arg.type = std::get<1>(ssaValueAndType);
1107 regionArguments.push_back(arg);
1108 }
1109
1110 Builder &builder = parser.getBuilder();
1111 // Parse workgroup memory attributions.
1112 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1113 regionArguments)))
1114 return failure();
1115
1116 // Store the number of operands we just parsed as the number of workgroup
1117 // memory attributions.
1118 unsigned numWorkgroupAttrs = regionArguments.size() -
1119 LaunchOp::kNumConfigRegionAttributes -
1120 (hasCluster ? 6 : 0);
1121 if (numWorkgroupAttrs != 0)
1122 result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
1123 builder.getI64IntegerAttr(numWorkgroupAttrs));
1124
1125 // Parse private memory attributions.
1126 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1127 regionArguments)))
1128 return failure();
1129
1130 // Introduce the body region and parse it. The region has
1131 // kNumConfigRegionAttributes arguments that correspond to
1132 // block/thread identifiers and grid/block sizes, all having `index` type.
1133 Region *body = result.addRegion();
1134 if (parser.parseRegion(*body, regionArguments) ||
1135 parser.parseOptionalAttrDict(result.attributes))
1136 return failure();
1137
1138 SmallVector<int32_t, 11> segmentSizes(11, 1);
1139 segmentSizes.front() = asyncDependencies.size();
1140
1141 if (!hasCluster) {
1142 segmentSizes[7] = 0;
1143 segmentSizes[8] = 0;
1144 segmentSizes[9] = 0;
1145 }
1146 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1147 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1148 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1149 return success();
1150}
1151
1152/// Simplify the gpu.launch when the range of a thread or block ID is
1153/// trivially known to be one.
1154struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1155 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1156 LogicalResult matchAndRewrite(LaunchOp op,
1157 PatternRewriter &rewriter) const override {
1158 // If the range implies a single value for `id`, replace `id`'s uses by
1159 // zero.
1160 Value zero;
1161 bool simplified = false;
1162 auto constPropIdUses = [&](Value id, Value size) {
1163 // Check if size is trivially one.
1164 if (!matchPattern(size, m_One()))
1165 return;
1166 if (id.getUses().empty())
1167 return;
1168 if (!simplified) {
1169 // Create a zero value the first time.
1170 OpBuilder::InsertionGuard guard(rewriter);
1171 rewriter.setInsertionPointToStart(&op.getBody().front());
1172 zero =
1173 arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
1174 }
1175 rewriter.replaceAllUsesWith(id, zero);
1176 simplified = true;
1177 };
1178 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1179 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1180 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1181 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1182 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1183 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1184
1185 return success(simplified);
1186 }
1187};
1188
1189void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1190 MLIRContext *context) {
1191 rewrites.add<FoldLaunchArguments>(context);
1192}
1193
1194/// Adds a new block argument that corresponds to buffers located in
1195/// workgroup memory.
1196BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1197 int64_t cur = getWorkgroupAttributions().value_or(0);
1198 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1199 return getBody().insertArgument(
1200 getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
1201}
1202
1203/// Adds a new block argument that corresponds to buffers located in
1204/// private memory.
1205BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1206 // Buffers on the private memory always come after buffers on the workgroup
1207 // memory.
1208 return getBody().addArgument(type, loc);
1209}
1210
1211//===----------------------------------------------------------------------===//
1212// LaunchFuncOp
1213//===----------------------------------------------------------------------===//
1214
1215void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1216 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1217 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1218 ValueRange kernelOperands, Type asyncTokenType,
1219 ValueRange asyncDependencies,
1220 std::optional<KernelDim3> clusterSize) {
1221 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1222 "expected a symbol reference with a single nested reference");
1223 result.addOperands(asyncDependencies);
1224 if (asyncTokenType)
1225 result.types.push_back(builder.getType<AsyncTokenType>());
1226
1227 // Add grid and block sizes as op operands, followed by the data operands.
1228 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1230 if (clusterSize.has_value())
1231 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1232 if (dynamicSharedMemorySize)
1233 result.addOperands(dynamicSharedMemorySize);
1234 result.addOperands(kernelOperands);
1235
1236 Properties &prop = result.getOrAddProperties<Properties>();
1237 prop.kernel = kernelSymbol;
1238 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1239 // Initialize the segment sizes to 1.
1240 llvm::fill(prop.operandSegmentSizes, 1);
1241 prop.operandSegmentSizes[0] = asyncDependencies.size();
1242 if (!clusterSize.has_value()) {
1243 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1244 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1245 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1246 }
1247 prop.operandSegmentSizes[segmentSizesLen - 3] =
1248 dynamicSharedMemorySize ? 1 : 0;
1249 prop.operandSegmentSizes[segmentSizesLen - 2] =
1250 static_cast<int32_t>(kernelOperands.size());
1251 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1252}
1253
1254void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1255 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1256 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1257 ValueRange kernelOperands, Type asyncTokenType,
1258 ValueRange asyncDependencies,
1259 std::optional<KernelDim3> clusterSize) {
1260 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1261 auto kernelSymbol =
1262 SymbolRefAttr::get(kernelModule.getNameAttr(),
1263 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1264 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1265 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1266 asyncDependencies, clusterSize);
1267}
1268
1269void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1270 SymbolRefAttr kernel, KernelDim3 gridSize,
1271 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1272 ValueRange kernelOperands, Value asyncObject,
1273 std::optional<KernelDim3> clusterSize) {
1274 // Add grid and block sizes as op operands, followed by the data operands.
1275 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1277 if (clusterSize.has_value())
1278 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1279 if (dynamicSharedMemorySize)
1280 result.addOperands(dynamicSharedMemorySize);
1281 result.addOperands(kernelOperands);
1282 if (asyncObject)
1283 result.addOperands(asyncObject);
1284 Properties &prop = result.getOrAddProperties<Properties>();
1285 prop.kernel = kernel;
1286 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1287 // Initialize the segment sizes to 1.
1288 llvm::fill(prop.operandSegmentSizes, 1);
1289 prop.operandSegmentSizes[0] = 0;
1290 if (!clusterSize.has_value()) {
1291 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1292 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1293 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1294 }
1295 prop.operandSegmentSizes[segmentSizesLen - 3] =
1296 dynamicSharedMemorySize ? 1 : 0;
1297 prop.operandSegmentSizes[segmentSizesLen - 2] =
1298 static_cast<int32_t>(kernelOperands.size());
1299 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1300}
1301
1302StringAttr LaunchFuncOp::getKernelModuleName() {
1303 return getKernel().getRootReference();
1304}
1305
1306StringAttr LaunchFuncOp::getKernelName() {
1307 return getKernel().getLeafReference();
1308}
1309
1310unsigned LaunchFuncOp::getNumKernelOperands() {
1311 return getKernelOperands().size();
1312}
1313
1314Value LaunchFuncOp::getKernelOperand(unsigned i) {
1315 return getKernelOperands()[i];
1316}
1317
1318KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1319 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1320 return KernelDim3{operands[0], operands[1], operands[2]};
1321}
1322
1323KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1324 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1325 return KernelDim3{operands[3], operands[4], operands[5]};
1326}
1327
1328KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1329 assert(hasClusterSize() &&
1330 "cluster size is not set, check hasClusterSize() first");
1331 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1332 return KernelDim3{operands[6], operands[7], operands[8]};
1333}
1334
1335LogicalResult LaunchFuncOp::verify() {
1336 auto module = (*this)->getParentOfType<ModuleOp>();
1337 if (!module)
1338 return emitOpError("expected to belong to a module");
1339
1340 if (!module->getAttrOfType<UnitAttr>(
1341 GPUDialect::getContainerModuleAttrName()))
1342 return emitOpError("expected the closest surrounding module to have the '" +
1343 GPUDialect::getContainerModuleAttrName() +
1344 "' attribute");
1345
1346 if (hasClusterSize()) {
1347 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1348 getClusterSizeZ().getType() != getClusterSizeX().getType())
1349 return emitOpError()
1350 << "expects types of the cluster dimensions must be the same";
1351 }
1352
1353 if (!getAsyncDependencies().empty() && getAsyncObject())
1354 return emitOpError(
1355 "cannot have both async dependencies and an explicit async object");
1356
1357 return success();
1358}
1359
1360LogicalResult
1361LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1362 LaunchFuncOp launchOp = *this;
1363 Operation *table = SymbolTable::getNearestSymbolTable(launchOp);
1364 // GPU modules cannot be nested within each other, escape to resolve the name.
1365 if (isa<GPUModuleOp>(table))
1367
1368 // Ignore launches that are nested more or less deep than functions in the
1369 // module we are currently checking.
1370 if (!launchOp->getParentOp() ||
1371 launchOp->getParentOp()->getParentOp() != table)
1372 return success();
1373
1374 // Ignore launch ops with missing attributes here. The errors will be
1375 // reported by the verifiers of those ops.
1376 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1377 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1378 return success();
1379
1380 // Check that `launch_func` refers to a well-formed GPU kernel container.
1381 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1382 Operation *kernelContainer =
1383 symbolTable.lookupNearestSymbolFrom(table, kernelContainerName);
1384 if (!kernelContainer)
1385 return launchOp.emitOpError()
1386 << "kernel container '" << kernelContainerName.getValue()
1387 << "' is undefined";
1388
1389 // If the container is a GPU binary op return success.
1390 if (isa<BinaryOp>(kernelContainer))
1391 return success();
1392
1393 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1394 if (!kernelModule)
1395 return launchOp.emitOpError()
1396 << "kernel module '" << kernelContainerName.getValue()
1397 << "' is undefined";
1398
1399 // Check that `launch_func` refers to a well-formed kernel function.
1400 Operation *kernelFunc = symbolTable.lookupNearestSymbolFrom(
1401 kernelModule, launchOp.getKernelName());
1402 if (!kernelFunc)
1403 return launchOp.emitOpError("kernel function '")
1404 << launchOp.getKernel() << "' is undefined";
1405 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1406 if (!kernelConvertedFunction) {
1407 InFlightDiagnostic diag = launchOp.emitOpError()
1408 << "referenced kernel '" << launchOp.getKernel()
1409 << "' is not a function";
1410 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
1411 return diag;
1412 }
1413
1414 if (!GPUDialect::isKernel(kernelFunc))
1415 return launchOp.emitOpError("kernel function is missing the '")
1416 << GPUDialect::getKernelFuncAttrName() << "' attribute";
1417
1418 // TODO: If the kernel isn't a GPU function (which happens during separate
1419 // compilation), do not check type correspondence as it would require the
1420 // verifier to be aware of the type conversion.
1421 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1422 if (!kernelGPUFunction)
1423 return success();
1424
1425 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1426 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1427 if (expectedNumArguments != actualNumArguments)
1428 return launchOp.emitOpError("got ")
1429 << actualNumArguments << " kernel operands but expected "
1430 << expectedNumArguments;
1431
1432 FunctionType functionType = kernelGPUFunction.getFunctionType();
1433 for (unsigned i = 0; i < expectedNumArguments; ++i) {
1434 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1435 return launchOp.emitOpError("type of function argument ")
1436 << i << " does not match";
1437 }
1438 }
1439
1440 return success();
1441}
1442
1443static ParseResult
1445 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1446 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1447 if (succeeded(parser.parseOptionalColon())) {
1448 if (parser.parseType(dimTy))
1449 return failure();
1450 } else {
1451 dimTy = IndexType::get(parser.getContext());
1452 }
1453 if (clusterValue.has_value()) {
1454 clusterXTy = clusterYTy = clusterZTy = dimTy;
1455 }
1456 return success();
1457}
1458
1459static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1460 Value clusterValue, Type clusterXTy,
1461 Type clusterYTy, Type clusterZTy) {
1462 if (!dimTy.isIndex())
1463 printer << ": " << dimTy;
1464}
1465
1466static ParseResult parseLaunchFuncOperands(
1467 OpAsmParser &parser,
1469 SmallVectorImpl<Type> &argTypes) {
1470 if (parser.parseOptionalKeyword("args"))
1471 return success();
1472
1473 auto parseElement = [&]() -> ParseResult {
1474 return failure(parser.parseOperand(argNames.emplace_back()) ||
1475 parser.parseColonType(argTypes.emplace_back()));
1476 };
1477
1479 parseElement, " in argument list");
1480}
1481
1483 OperandRange operands, TypeRange types) {
1484 if (operands.empty())
1485 return;
1486 printer << "args(";
1487 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1488 [&](const auto &pair) {
1489 auto [operand, type] = pair;
1490 printer << operand << " : " << type;
1491 });
1492 printer << ")";
1493}
1494
1495//===----------------------------------------------------------------------===//
1496// ShuffleOp
1497//===----------------------------------------------------------------------===//
1498
1499void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1500 int32_t offset, int32_t width, ShuffleMode mode) {
1501 build(builder, result, value,
1502 arith::ConstantOp::create(builder, result.location,
1503 builder.getI32IntegerAttr(offset)),
1504 arith::ConstantOp::create(builder, result.location,
1505 builder.getI32IntegerAttr(width)),
1506 mode);
1507}
1508
1509//===----------------------------------------------------------------------===//
1510// RotateOp
1511//===----------------------------------------------------------------------===//
1512
1513LogicalResult RotateOp::verify() {
1514 uint32_t offset = getOffset();
1515 uint32_t width = getWidth();
1516
1517 if (offset >= width) {
1518 return emitOpError() << "offset must be in the range [0, " << width << ")";
1519 }
1520
1521 return success();
1522}
1523
1524//===----------------------------------------------------------------------===//
1525// BarrierOp
1526//===----------------------------------------------------------------------===//
1527
1528LogicalResult BarrierOp::verify() {
1529 BarrierScope scope = getScope();
1530
1531 if (getNamedBarrier() && scope != BarrierScope::Workgroup)
1532 return emitOpError("named barriers require workgroup scope");
1533
1534 return success();
1535}
1536
1537/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1538static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1539 PatternRewriter &rewriter) {
1540 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1541 if (!nextOp)
1542 return failure();
1543
1544 // Cannot merge barriers of different scopes.
1545 if (op.getScope() != nextOp.getScope())
1546 return failure();
1547
1548 // Cannot merge named barriers unless both refer to the same handle.
1549 if (op.getNamedBarrier() != nextOp.getNamedBarrier())
1550 return failure();
1551
1552 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1553 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1554
1555 if (thisMemfence) {
1556 rewriter.modifyOpInPlace(op, [&]() {
1557 if (!nextMemfence) {
1558 op.removeAddressSpacesAttr();
1559 return;
1560 }
1561 // Fast path - merge where the two barriers fence the same spaces.
1562 if (*thisMemfence == *nextMemfence) {
1563 return;
1564 }
1565
1566 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1567 for (Attribute attr : *thisMemfence)
1568 mergedSpaces.insert(attr);
1569 for (Attribute attr : *nextMemfence)
1570 mergedSpaces.insert(attr);
1571 op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
1572 });
1573 }
1574
1575 rewriter.eraseOp(nextOp);
1576 return success();
1577}
1578
1579void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1580 MLIRContext *context) {
1582}
1583
1584void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1585 mlir::OperationState &odsState,
1586 std::optional<AddressSpace> addressSpace) {
1587 ArrayAttr addressSpacesAttr;
1588 if (addressSpace)
1589 addressSpacesAttr = odsBuilder.getArrayAttr(
1590 AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
1591 build(
1592 odsBuilder, odsState, addressSpacesAttr, /*named_barrier=*/Value{},
1593 BarrierScopeAttr::get(odsBuilder.getContext(), BarrierScope::Workgroup));
1594}
1595
1596/// Builds a barrier that causes memory operations affecting `memrefToFence` to
1597/// be completed after the barrier is concluded. Currently, this means setting
1598/// the fenced address spaces to those of the given memref if it is a gpu
1599/// address space.
1600void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1601 Value memrefToFence) {
1602 std::optional<AddressSpace> addrSpaceToFence;
1603 if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
1604 if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1605 memrefType.getMemorySpace()))
1606 addrSpaceToFence = addrSpaceAttr.getValue();
1607 return build(builder, odsState, addrSpaceToFence);
1608}
1609
1610//===----------------------------------------------------------------------===//
1611// GPUFuncOp
1612//===----------------------------------------------------------------------===//
1613
1614/// Adds a new block argument that corresponds to buffers located in
1615/// workgroup memory.
1616BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1617 int64_t cur = getWorkgroupAttributions().value_or(0);
1618 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1619 return getBody().insertArgument(
1620 getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
1621}
1622
1623/// Adds a new block argument that corresponds to buffers located in
1624/// private memory.
1625BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1626 // Buffers on the private memory always come after buffers on the workgroup
1627 // memory.
1628 return getBody().addArgument(type, loc);
1629}
1630
1631void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1632 StringRef name, FunctionType type,
1633 TypeRange workgroupAttributions,
1634 TypeRange privateAttributions,
1635 ArrayRef<NamedAttribute> attrs) {
1636 OpBuilder::InsertionGuard g(builder);
1637
1639 builder.getStringAttr(name));
1640 result.addAttribute(getFunctionTypeAttrName(result.name),
1641 TypeAttr::get(type));
1642 result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
1643 builder.getI64IntegerAttr(workgroupAttributions.size()));
1644 result.addAttributes(attrs);
1645 Region *body = result.addRegion();
1646 Block *entryBlock = builder.createBlock(body);
1647
1648 // TODO: Allow passing in proper locations here.
1649 for (Type argTy : type.getInputs())
1650 entryBlock->addArgument(argTy, result.location);
1651 for (Type argTy : workgroupAttributions)
1652 entryBlock->addArgument(argTy, result.location);
1653 for (Type argTy : privateAttributions)
1654 entryBlock->addArgument(argTy, result.location);
1655}
1656
1657/// Parses a GPU function memory attribution.
1658///
1659/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1660/// (`private` `(` ssa-id-and-type-list `)`)?
1661///
1662/// Note that this function parses only one of the two similar parts, with the
1663/// keyword provided as argument.
1664static ParseResult
1665parseAttributions(OpAsmParser &parser, StringRef keyword,
1667 Attribute &attributionAttrs) {
1668 // If we could not parse the keyword, just assume empty list and succeed.
1669 if (failed(parser.parseOptionalKeyword(keyword)))
1670 return success();
1671
1672 size_t existingArgs = args.size();
1673 ParseResult result =
1675 /*allowType=*/true, /*allowAttrs=*/true);
1676 if (failed(result))
1677 return result;
1678
1679 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1680 [](const OpAsmParser::Argument &arg) -> bool {
1681 return arg.attrs && !arg.attrs.empty();
1682 });
1683 if (!hadAttrs) {
1684 attributionAttrs = nullptr;
1685 return result;
1686 }
1687
1688 Builder &builder = parser.getBuilder();
1689 SmallVector<Attribute> attributionAttrsVec;
1690 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1691 if (!argument.attrs)
1692 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1693 else
1694 attributionAttrsVec.push_back(argument.attrs);
1695 }
1696 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1697 return result;
1698}
1699
1700/// Parses a GPU function.
1701///
1702/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1703/// (`->` function-result-list)? memory-attribution `kernel`?
1704/// function-attributes? region
1705ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1706 SmallVector<OpAsmParser::Argument> entryArgs;
1707 SmallVector<DictionaryAttr> resultAttrs;
1708 SmallVector<Type> resultTypes;
1709 bool isVariadic;
1710
1711 // Parse the function name.
1712 StringAttr nameAttr;
1714 result.attributes))
1715 return failure();
1716
1717 auto signatureLocation = parser.getCurrentLocation();
1719 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1720 resultAttrs)))
1721 return failure();
1722
1723 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1724 return parser.emitError(signatureLocation)
1725 << "gpu.func requires named arguments";
1726
1727 // Construct the function type. More types will be added to the region, but
1728 // not to the function type.
1729 Builder &builder = parser.getBuilder();
1730
1731 SmallVector<Type> argTypes;
1732 for (auto &arg : entryArgs)
1733 argTypes.push_back(arg.type);
1734 auto type = builder.getFunctionType(argTypes, resultTypes);
1735 result.addAttribute(getFunctionTypeAttrName(result.name),
1736 TypeAttr::get(type));
1737
1739 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1740 getResAttrsAttrName(result.name));
1741
1742 Attribute workgroupAttributionAttrs;
1743 // Parse workgroup memory attributions.
1744 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1745 entryArgs, workgroupAttributionAttrs)))
1746 return failure();
1747
1748 // Store the number of operands we just parsed as the number of workgroup
1749 // memory attributions.
1750 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1751 if (numWorkgroupAttrs != 0)
1752 result.addAttribute(
1753 GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
1754 builder.getI64IntegerAttr(numWorkgroupAttrs));
1755 if (workgroupAttributionAttrs)
1756 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1757 workgroupAttributionAttrs);
1758
1759 Attribute privateAttributionAttrs;
1760 // Parse private memory attributions.
1761 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1762 entryArgs, privateAttributionAttrs)))
1763 return failure();
1764 if (privateAttributionAttrs)
1765 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1766 privateAttributionAttrs);
1767
1768 // Parse the kernel attribute if present.
1769 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1770 result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
1771 builder.getUnitAttr());
1772
1773 // Parse attributes.
1774 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1775 return failure();
1776
1777 // Parse the region. If no argument names were provided, take all names
1778 // (including those of attributions) from the entry block.
1779 auto *body = result.addRegion();
1780 return parser.parseRegion(*body, entryArgs);
1781}
1782
1783void GPUFuncOp::print(OpAsmPrinter &p) {
1784 p << ' ';
1785 p.printSymbolName(getName());
1786
1787 FunctionType type = getFunctionType();
1788 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1789 /*isVariadic=*/false,
1790 type.getResults());
1791
1792 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs(),
1793 getWorkgroupAttribAttrs().value_or(nullptr));
1794 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1795 getPrivateAttribAttrs().value_or(nullptr));
1796 if (isKernel())
1797 p << ' ' << getKernelKeyword();
1798
1800 p, *this,
1801 {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
1802 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1803 getArgAttrsAttrName(), getResAttrsAttrName(),
1804 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1805 p << ' ';
1806 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1807}
1808
1809static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1810 StringAttr attrName) {
1811 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1812 if (!allAttrs || index >= allAttrs.size())
1813 return DictionaryAttr();
1814 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1815}
1816
1817DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1818 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1819}
1820
1821DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1822 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1823}
1824
1825static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1826 DictionaryAttr value, StringAttr attrName) {
1827 MLIRContext *ctx = op.getContext();
1828 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1829 SmallVector<Attribute> elements;
1830 if (allAttrs)
1831 elements.append(allAttrs.begin(), allAttrs.end());
1832 while (elements.size() <= index)
1833 elements.push_back(DictionaryAttr::get(ctx));
1834 if (!value)
1835 elements[index] = DictionaryAttr::get(ctx);
1836 else
1837 elements[index] = value;
1838 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1839 op->setAttr(attrName, newValue);
1840}
1841
1842void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1843 DictionaryAttr value) {
1844 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1845}
1846
1847void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1848 DictionaryAttr value) {
1849 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1850}
1851
1852static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1853 StringAttr name, StringAttr attrsName) {
1854 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1855 if (!dict)
1856 return Attribute();
1857 return dict.get(name);
1858}
1859
1860Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1861 StringAttr name) {
1862 assert(index < getNumWorkgroupAttributions() &&
1863 "index must map to a workgroup attribution");
1864 return getAttributionAttr(*this, index, name,
1865 getWorkgroupAttribAttrsAttrName());
1866}
1867
1868Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1869 StringAttr name) {
1870 assert(index < getNumPrivateAttributions() &&
1871 "index must map to a private attribution");
1872 return getAttributionAttr(*this, index, name,
1873 getPrivateAttribAttrsAttrName());
1874}
1875
1876static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1877 Attribute value, StringAttr attrsName) {
1878 MLIRContext *ctx = op.getContext();
1880 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1881 if (oldDict)
1882 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1883
1884 bool found = false;
1885 bool mustSort = true;
1886 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1887 if (elems[i].getName() == name) {
1888 found = true;
1889 if (!value) {
1890 std::swap(elems[i], elems[elems.size() - 1]);
1891 elems.pop_back();
1892 } else {
1893 mustSort = false;
1894 elems[i] = NamedAttribute(elems[i].getName(), value);
1895 }
1896 break;
1897 }
1898 }
1899 if (!found) {
1900 if (!value)
1901 return;
1902 elems.emplace_back(name, value);
1903 }
1904 if (mustSort) {
1905 DictionaryAttr::sortInPlace(elems);
1906 }
1907 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1908 setAttributionAttrs(op, index, newDict, attrsName);
1909}
1910
1911void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1912 Attribute value) {
1913 assert(index < getNumWorkgroupAttributions() &&
1914 "index must map to a workgroup attribution");
1915 setAttributionAttr(*this, index, name, value,
1916 getWorkgroupAttribAttrsAttrName());
1917}
1918
1919void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1920 Attribute value) {
1921 assert(index < getNumPrivateAttributions() &&
1922 "index must map to a private attribution");
1923 setAttributionAttr(*this, index, name, value,
1924 getPrivateAttribAttrsAttrName());
1925}
1926
1927LogicalResult GPUFuncOp::verifyType() {
1928 if (isKernel() && getFunctionType().getNumResults() != 0)
1929 return emitOpError() << "expected void return type for kernel function";
1930
1931 return success();
1932}
1933
1934/// Verifies the body of the function.
1935LogicalResult GPUFuncOp::verifyBody() {
1936 if (empty())
1937 return emitOpError() << "expected body with at least one block";
1938 unsigned numFuncArguments = getNumArguments();
1939 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1940 unsigned numBlockArguments = front().getNumArguments();
1941 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1942 return emitOpError() << "expected at least "
1943 << numFuncArguments + numWorkgroupAttributions
1944 << " arguments to body region";
1945
1946 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1947 for (unsigned i = 0; i < numFuncArguments; ++i) {
1948 Type blockArgType = front().getArgument(i).getType();
1949 if (funcArgTypes[i] != blockArgType)
1950 return emitOpError() << "expected body region argument #" << i
1951 << " to be of type " << funcArgTypes[i] << ", got "
1952 << blockArgType;
1953 }
1954
1955 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
1956 GPUDialect::getWorkgroupAddressSpace())) ||
1957 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1958 GPUDialect::getPrivateAddressSpace())))
1959 return failure();
1960
1961 return success();
1962}
1963
1964//===----------------------------------------------------------------------===//
1965// ReturnOp
1966//===----------------------------------------------------------------------===//
1967
1968LogicalResult gpu::ReturnOp::verify() {
1969 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1970
1971 FunctionType funType = function.getFunctionType();
1972
1973 if (funType.getNumResults() != getOperands().size())
1974 return emitOpError()
1975 .append("expected ", funType.getNumResults(), " result operands")
1976 .attachNote(function.getLoc())
1977 .append("return type declared here");
1978
1979 for (const auto &pair : llvm::enumerate(
1980 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1981 auto [type, operand] = pair.value();
1982 if (type != operand.getType())
1983 return emitOpError() << "unexpected type `" << operand.getType()
1984 << "' for operand #" << pair.index();
1985 }
1986 return success();
1987}
1988
1989//===----------------------------------------------------------------------===//
1990// GPUModuleOp
1991//===----------------------------------------------------------------------===//
1992
1993void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1994 StringRef name, ArrayAttr targets,
1995 Attribute offloadingHandler) {
1996 result.addRegion()->emplaceBlock();
1997 Properties &props = result.getOrAddProperties<Properties>();
1998 if (targets)
1999 props.targets = targets;
2000 props.setSymName(builder.getStringAttr(name));
2001 props.offloadingHandler = offloadingHandler;
2002}
2003
2004void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
2005 StringRef name, ArrayRef<Attribute> targets,
2006 Attribute offloadingHandler) {
2007 build(builder, result, name,
2008 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
2009 offloadingHandler);
2010}
2011
2012bool GPUModuleOp::hasTarget(Attribute target) {
2013 if (ArrayAttr targets = getTargetsAttr())
2014 return llvm::count(targets.getValue(), target);
2015 return false;
2016}
2017
2018void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
2019 ArrayAttr &targetsAttr = getProperties().targets;
2020 SmallVector<Attribute> targetsVector(targets);
2021 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
2022}
2023
2024LogicalResult GPUModuleOp::verify() {
2025 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
2026
2027 if (!targets)
2028 return success();
2029
2030 for (auto target : targets) {
2031 if (auto verifyTargetAttr =
2032 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
2033 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
2034 return failure();
2035 }
2036 }
2037 return success();
2038}
2039
2040//===----------------------------------------------------------------------===//
2041// GPUBinaryOp
2042//===----------------------------------------------------------------------===//
2043void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2044 Attribute offloadingHandler, ArrayAttr objects) {
2045 auto &properties = result.getOrAddProperties<Properties>();
2046 result.attributes.push_back(builder.getNamedAttr(
2048 properties.objects = objects;
2049 if (offloadingHandler)
2050 properties.offloadingHandler = offloadingHandler;
2051 else
2052 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
2053}
2054
2055void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2056 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2057 build(builder, result, name, offloadingHandler,
2058 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
2059}
2060
2061static ParseResult parseOffloadingHandler(OpAsmParser &parser,
2062 Attribute &offloadingHandler) {
2063 if (succeeded(parser.parseOptionalLess())) {
2064 if (parser.parseAttribute(offloadingHandler))
2065 return failure();
2066 if (parser.parseGreater())
2067 return failure();
2068 }
2069 if (!offloadingHandler)
2070 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2071 return success();
2072}
2073
2075 Attribute offloadingHandler) {
2076 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2077 printer << '<' << offloadingHandler << '>';
2078}
2079
2080//===----------------------------------------------------------------------===//
2081// GPUMemcpyOp
2082//===----------------------------------------------------------------------===//
2083
2084LogicalResult MemcpyOp::verify() {
2085 auto srcType = getSrc().getType();
2086 auto dstType = getDst().getType();
2087
2088 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2089 return emitOpError("arguments have incompatible element type");
2090
2091 if (failed(verifyCompatibleShape(srcType, dstType)))
2092 return emitOpError("arguments have incompatible shape");
2093
2094 return success();
2095}
2096
2097namespace {
2098
2099/// Erases a common case of copy ops where a destination value is used only by
2100/// the copy op, alloc and dealloc ops.
2101struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2102 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2103
2104 LogicalResult matchAndRewrite(MemcpyOp op,
2105 PatternRewriter &rewriter) const override {
2106 Value dest = op.getDst();
2107 Operation *destDefOp = dest.getDefiningOp();
2108 // `dest` must be defined by an op having Allocate memory effect in order to
2109 // perform the folding.
2110 if (!destDefOp ||
2112 return failure();
2113 // We can erase `op` iff `dest` has no other use apart from its
2114 // use by `op` and dealloc ops.
2115 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2116 return user != op &&
2117 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2118 }))
2119 return failure();
2120 // We can perform the folding if and only if op has a single async
2121 // dependency and produces an async token as result, or if it does not have
2122 // any async dependency and does not produce any async token result.
2123 if (op.getAsyncDependencies().size() > 1 ||
2124 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2125 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2126 return failure();
2127 rewriter.replaceOp(op, op.getAsyncDependencies());
2128 return success();
2129 }
2130};
2131
2132} // end anonymous namespace
2133
2134void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2135 MLIRContext *context) {
2136 results.add<EraseTrivialCopyOp>(context);
2137}
2138
2139//===----------------------------------------------------------------------===//
2140// GPU_SubgroupMmaLoadMatrixOp
2141//===----------------------------------------------------------------------===//
2142
2143LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2144 auto srcType = getSrcMemref().getType();
2145 auto resType = getRes().getType();
2146 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2147 auto operand = resMatrixType.getOperand();
2148 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2149
2150 if (!srcMemrefType.isLastDimUnitStride())
2151 return emitError(
2152 "expected source memref most minor dim must have unit stride");
2153
2154 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2155 return emitError("only AOp, BOp and COp can be loaded");
2156
2157 return success();
2158}
2159
2160//===----------------------------------------------------------------------===//
2161// GPU_SubgroupMmaStoreMatrixOp
2162//===----------------------------------------------------------------------===//
2163
2164LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2165 auto srcType = getSrc().getType();
2166 auto dstType = getDstMemref().getType();
2167 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2168 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2169
2170 if (!dstMemrefType.isLastDimUnitStride())
2171 return emitError(
2172 "expected destination memref most minor dim must have unit stride");
2173
2174 if (srcMatrixType.getOperand() != "COp")
2175 return emitError(
2176 "expected the operand matrix being stored to have 'COp' operand type");
2177
2178 return success();
2179}
2180
2181//===----------------------------------------------------------------------===//
2182// GPU_SubgroupMmaComputeOp
2183//===----------------------------------------------------------------------===//
2184
2185LogicalResult SubgroupMmaComputeOp::verify() {
2186 enum OperandMap { A, B, C };
2187 SmallVector<MMAMatrixType, 3> opTypes;
2188 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2189 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2190 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2191
2192 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2193 opTypes[C].getOperand() != "COp")
2194 return emitError("operands must be in the order AOp, BOp, COp");
2195
2196 ArrayRef<int64_t> aShape, bShape, cShape;
2197 aShape = opTypes[A].getShape();
2198 bShape = opTypes[B].getShape();
2199 cShape = opTypes[C].getShape();
2200
2201 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2202 bShape[1] != cShape[1])
2203 return emitError("operand shapes do not satisfy matmul constraints");
2204
2205 return success();
2206}
2207
2208LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2209 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2210 return memref::foldMemRefCast(*this);
2211}
2212
2213LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2214 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2215 return memref::foldMemRefCast(*this);
2216}
2217
2218//===----------------------------------------------------------------------===//
2219// GPU_WaitOp
2220//===----------------------------------------------------------------------===//
2221
2222namespace {
2223
2224/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2225/// %t = gpu.wait async [] // No async dependencies.
2226/// ... gpu.wait ... [%t, ...] // %t can be removed.
2227struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2228public:
2230
2231 LogicalResult matchAndRewrite(WaitOp op,
2232 PatternRewriter &rewriter) const final {
2233 auto predicate = [](Value value) {
2234 auto waitOp = value.getDefiningOp<WaitOp>();
2235 return waitOp && waitOp->getNumOperands() == 0;
2236 };
2237 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2238 return failure();
2239 SmallVector<Value> validOperands;
2240 for (Value operand : op->getOperands()) {
2241 if (predicate(operand))
2242 continue;
2243 validOperands.push_back(operand);
2244 }
2245 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2246 return success();
2247 }
2248};
2249
2250/// Simplify trivial gpu.wait ops for the following patterns.
2251/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2252/// dependencies).
2253/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2254/// %t0.
2255/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2256/// dependencies nor return any token.
2257struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2258public:
2260
2261 LogicalResult matchAndRewrite(WaitOp op,
2262 PatternRewriter &rewriter) const final {
2263 // Erase gpu.wait ops that neither have any async dependencies nor return
2264 // any async token.
2265 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2266 rewriter.eraseOp(op);
2267 return success();
2268 }
2269 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2270 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2271 op.getAsyncToken()) {
2272 rewriter.replaceOp(op, op.getAsyncDependencies());
2273 return success();
2274 }
2275 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2276 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2277 rewriter.eraseOp(op);
2278 return success();
2279 }
2280 return failure();
2281 }
2282};
2283
2284} // end anonymous namespace
2285
2286void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287 MLIRContext *context) {
2288 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2289}
2290
2291//===----------------------------------------------------------------------===//
2292// GPU_AllocOp
2293//===----------------------------------------------------------------------===//
2294
2295LogicalResult AllocOp::verify() {
2296 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2297
2298 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2299 getDynamicSizes())))
2300 return failure();
2301
2302 unsigned numSymbols = 0;
2303 if (!memRefType.getLayout().isIdentity())
2304 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2305 if (getSymbolOperands().size() != numSymbols) {
2306 return emitOpError(
2307 "symbol operand count does not equal memref symbol count");
2308 }
2309
2310 return success();
2311}
2312
2313namespace {
2314
2315/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2316/// `memref::AllocOp`.
2317struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2318 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2319
2320 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2321 PatternRewriter &rewriter) const override {
2322 std::optional<int64_t> index = dimOp.getConstantIndex();
2323 if (!index)
2324 return failure();
2325
2326 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2327 if (!memrefType || index.value() >= memrefType.getRank() ||
2328 !memrefType.isDynamicDim(index.value()))
2329 return failure();
2330
2331 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2332 if (!alloc)
2333 return failure();
2334
2335 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2336 memrefType.getDynamicDimIndex(index.value()));
2337 rewriter.replaceOp(dimOp, substituteOp);
2338 return success();
2339 }
2340};
2341
2342} // namespace
2343
2344void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2345 MLIRContext *context) {
2346 results.add<SimplifyDimOfAllocOp>(context);
2347}
2348
2349//===----------------------------------------------------------------------===//
2350// GPU object attribute
2351//===----------------------------------------------------------------------===//
2352
2353LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2354 Attribute target, CompilationTarget format,
2355 StringAttr object, DictionaryAttr properties,
2356 KernelTableAttr kernels) {
2357 if (!target)
2358 return emitError() << "the target attribute cannot be null";
2359 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2360 return success();
2361 return emitError() << "the target attribute must implement or promise the "
2362 "`gpu::TargetAttrInterface`";
2363}
2364
2365namespace {
2366ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2367 StringAttr &object) {
2368 std::optional<CompilationTarget> formatResult;
2369 StringRef enumKeyword;
2370 auto loc = odsParser.getCurrentLocation();
2371 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2372 formatResult = CompilationTarget::Fatbin;
2373 if (!formatResult &&
2374 (formatResult =
2375 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2376 odsParser.parseEqual())
2377 return odsParser.emitError(loc, "expected an equal sign");
2378 if (!formatResult)
2379 return odsParser.emitError(loc, "expected keyword for GPU object format");
2380 FailureOr<StringAttr> objectResult =
2381 FieldParser<StringAttr>::parse(odsParser);
2382 if (failed(objectResult))
2383 return odsParser.emitError(odsParser.getCurrentLocation(),
2384 "failed to parse GPU_ObjectAttr parameter "
2385 "'object' which is to be a `StringAttr`");
2386 format = *formatResult;
2387 object = *objectResult;
2388 return success();
2389}
2390
2391void printObject(AsmPrinter &odsParser, CompilationTarget format,
2392 StringAttr object) {
2393 if (format != CompilationTarget::Fatbin)
2394 odsParser << stringifyEnum(format) << " = ";
2395 odsParser << object;
2396}
2397} // namespace
2398
2399//===----------------------------------------------------------------------===//
2400// GPU select object attribute
2401//===----------------------------------------------------------------------===//
2402
2403LogicalResult
2404gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2405 Attribute target) {
2406 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2407 if (target) {
2408 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2409 if (intAttr.getInt() < 0) {
2410 return emitError() << "the object index must be positive";
2411 }
2412 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2413 return emitError()
2414 << "the target attribute must be a GPU Target attribute";
2415 }
2416 }
2417 return success();
2418}
2419
2420//===----------------------------------------------------------------------===//
2421// DynamicSharedMemoryOp
2422//===----------------------------------------------------------------------===//
2423
2424LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2425 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2426 return emitOpError() << "must be inside an op with symbol table";
2427
2428 MemRefType memrefType = getResultMemref().getType();
2429 // Check address space
2430 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2431 return emitOpError() << "address space must be "
2432 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2433 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2434 }
2435 if (memrefType.hasStaticShape()) {
2436 return emitOpError() << "result memref type must be memref<?xi8, "
2437 "#gpu.address_space<workgroup>>";
2438 }
2439 return success();
2440}
2441
2442//===----------------------------------------------------------------------===//
2443// GPU WarpExecuteOnLane0Op
2444//===----------------------------------------------------------------------===//
2445
2446void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2447 p << "(" << getLaneid() << ")";
2448
2449 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2450 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2451 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2452
2453 if (!getArgs().empty())
2454 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2455 if (!getResults().empty())
2456 p << " -> (" << getResults().getTypes() << ')';
2457 p << " ";
2458 p.printRegion(getRegion(),
2459 /*printEntryBlockArgs=*/true,
2460 /*printBlockTerminators=*/!getResults().empty());
2461 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2462}
2463
2464ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2465 OperationState &result) {
2466 // Create the region.
2467 result.regions.reserve(1);
2468 Region *warpRegion = result.addRegion();
2469
2470 auto &builder = parser.getBuilder();
2471 OpAsmParser::UnresolvedOperand laneId;
2472
2473 // Parse predicate operand.
2474 if (parser.parseLParen() ||
2475 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2476 parser.parseRParen())
2477 return failure();
2478
2479 int64_t warpSize;
2480 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2481 parser.parseRSquare())
2482 return failure();
2483 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2484 builder.getContext())),
2485 builder.getI64IntegerAttr(warpSize));
2486
2487 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2488 return failure();
2489
2490 llvm::SMLoc inputsOperandsLoc;
2491 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2492 SmallVector<Type> inputTypes;
2493 if (succeeded(parser.parseOptionalKeyword("args"))) {
2494 if (parser.parseLParen())
2495 return failure();
2496
2497 inputsOperandsLoc = parser.getCurrentLocation();
2498 if (parser.parseOperandList(inputsOperands) ||
2499 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2500 return failure();
2501 }
2502 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2503 result.operands))
2504 return failure();
2505
2506 // Parse optional results type list.
2507 if (parser.parseOptionalArrowTypeList(result.types))
2508 return failure();
2509 // Parse the region.
2510 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2511 /*argTypes=*/{}))
2512 return failure();
2513 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2514
2515 // Parse the optional attribute list.
2516 if (parser.parseOptionalAttrDict(result.attributes))
2517 return failure();
2518 return success();
2519}
2520
2521void WarpExecuteOnLane0Op::getSuccessorRegions(
2522 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2523 if (!point.isParent()) {
2524 regions.push_back(RegionSuccessor::parent());
2525 return;
2526 }
2527
2528 // The warp region is always executed
2529 regions.push_back(RegionSuccessor(&getWarpRegion()));
2530}
2531
2532ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2533 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2534}
2535void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2536 TypeRange resultTypes, Value laneId,
2537 int64_t warpSize) {
2538 build(builder, result, resultTypes, laneId, warpSize,
2539 /*operands=*/{}, /*argTypes=*/{});
2540}
2541
2542void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2543 TypeRange resultTypes, Value laneId,
2544 int64_t warpSize, ValueRange args,
2545 TypeRange blockArgTypes) {
2546 result.addOperands(laneId);
2547 result.addAttribute(getAttributeNames()[0],
2548 builder.getI64IntegerAttr(warpSize));
2549 result.addTypes(resultTypes);
2550 result.addOperands(args);
2551 assert(args.size() == blockArgTypes.size());
2552 OpBuilder::InsertionGuard guard(builder);
2553 Region *warpRegion = result.addRegion();
2554 Block *block = builder.createBlock(warpRegion);
2555 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2556 block->addArgument(type, arg.getLoc());
2557}
2558
2559/// Helper check if the distributed vector type is consistent with the expanded
2560/// type and distributed size.
2561static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2562 int64_t warpSize, Operation *op) {
2563 // If the types matches there is no distribution.
2564 if (expanded == distributed)
2565 return success();
2566 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2567 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2568 if (!expandedVecType || !distributedVecType)
2569 return op->emitOpError("expected vector type for distributed operands.");
2570 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2571 expandedVecType.getElementType() != distributedVecType.getElementType())
2572 return op->emitOpError(
2573 "expected distributed vectors to have same rank and element type.");
2574
2575 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2576 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2577 int64_t eDim = expandedVecType.getDimSize(i);
2578 int64_t dDim = distributedVecType.getDimSize(i);
2579 if (eDim == dDim)
2580 continue;
2581 if (eDim % dDim != 0)
2582 return op->emitOpError()
2583 << "expected expanded vector dimension #" << i << " (" << eDim
2584 << ") to be a multipler of the distributed vector dimension ("
2585 << dDim << ")";
2586 scales[i] = eDim / dDim;
2587 }
2588 if (llvm::product_of(scales) != warpSize)
2589 return op->emitOpError()
2590 << "incompatible distribution dimensions from " << expandedVecType
2591 << " to " << distributedVecType << " with warp size = " << warpSize;
2592
2593 return success();
2594}
2595
2596LogicalResult WarpExecuteOnLane0Op::verify() {
2597 if (getArgs().size() != getWarpRegion().getNumArguments())
2598 return emitOpError(
2599 "expected same number op arguments and block arguments.");
2600 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2601 if (!yield)
2602 return emitOpError("expected body to be terminated with 'gpu.yield'");
2603 if (yield.getNumOperands() != getNumResults())
2604 return emitOpError(
2605 "expected same number of yield operands and return values.");
2606 int64_t warpSize = getWarpSize();
2607 for (auto [regionArg, arg] :
2608 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2609 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2610 warpSize, getOperation())))
2611 return failure();
2612 }
2613 for (auto [yieldOperand, result] :
2614 llvm::zip_equal(yield.getOperands(), getResults())) {
2615 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2616 warpSize, getOperation())))
2617 return failure();
2618 }
2619 return success();
2620}
2621bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2622 return succeeded(
2623 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2624}
2625
2626gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2627 return cast<gpu::YieldOp>(getBody()->getTerminator());
2628}
2629
2630//===----------------------------------------------------------------------===//
2631// GPU_SubgroupBroadcastOp
2632//===----------------------------------------------------------------------===//
2633
2634void gpu::SubgroupBroadcastOp::inferResultRanges(
2635 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2636 setResultRange(getResult(), argRanges.front());
2637}
2638
2639Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2640 switch (getBroadcastType()) {
2641 case BroadcastType::first_active_lane:
2642 // Cannot speculate first_lane broadcast, because speculating it across
2643 // control flow can change the active lanes.
2645 case BroadcastType::specific_lane:
2646 // Speculation should be safe as long as we inside structured control flow.
2648 }
2649 llvm_unreachable("Unknown BroadcastType");
2650}
2651
2652LogicalResult gpu::SubgroupBroadcastOp::verify() {
2653 switch (getBroadcastType()) {
2654 case BroadcastType::first_active_lane:
2655 if (getLane())
2656 return emitOpError()
2657 << "lane can only be specified for `specific_lane` broadcast";
2658 return success();
2659 case BroadcastType::specific_lane:
2660 if (!getLane())
2661 return emitOpError()
2662 << "lane must be specified for `specific_lane` broadcast";
2663 return success();
2664 }
2665 llvm_unreachable("Unknown BroadcastType");
2666}
2667
2668OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2669 // Broadcast result is always uniform.
2670 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2671 return prev.getResult();
2672
2673 return nullptr;
2674}
2675
2676//===----------------------------------------------------------------------===//
2677// GPU_BallotOp
2678//===----------------------------------------------------------------------===//
2679
2680// No custom implementations needed; ballot uses default behavior from ODS.
2681
2682//===----------------------------------------------------------------------===//
2683// GPU KernelMetadataAttr
2684//===----------------------------------------------------------------------===//
2685
2686KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2687 DictionaryAttr metadata) {
2688 assert(kernel && "invalid kernel");
2689 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2690 kernel.getAllArgAttrs(), metadata);
2691}
2692
2693KernelMetadataAttr
2694KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2695 FunctionOpInterface kernel,
2696 DictionaryAttr metadata) {
2697 assert(kernel && "invalid kernel");
2698 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2699 kernel.getAllArgAttrs(), metadata);
2700}
2701
2702KernelMetadataAttr
2703KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2704 if (attrs.empty())
2705 return *this;
2706 NamedAttrList attrList;
2707 if (DictionaryAttr dict = getMetadata())
2708 attrList.append(dict);
2709 attrList.append(attrs);
2710 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2711 attrList.getDictionary(getContext()));
2712}
2713
2714LogicalResult
2715KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2716 StringAttr name, Type functionType,
2717 ArrayAttr argAttrs, DictionaryAttr metadata) {
2718 if (name.empty())
2719 return emitError() << "the kernel name can't be empty";
2720 if (argAttrs) {
2721 if (llvm::any_of(argAttrs, [](Attribute attr) {
2722 return !llvm::isa<DictionaryAttr>(attr);
2723 }))
2724 return emitError()
2725 << "all attributes in the array must be a dictionary attribute";
2726 }
2727 return success();
2728}
2729
2730//===----------------------------------------------------------------------===//
2731// GPU KernelTableAttr
2732//===----------------------------------------------------------------------===//
2733
2734KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2735 ArrayRef<KernelMetadataAttr> kernels,
2736 bool isSorted) {
2737 // Note that `is_sorted` is always only invoked once even with assertions ON.
2738 assert((!isSorted || llvm::is_sorted(kernels)) &&
2739 "expected a sorted kernel array");
2740 // Immediately return the attribute if the array is sorted.
2741 if (isSorted || llvm::is_sorted(kernels))
2742 return Base::get(context, kernels);
2743 // Sort the array.
2744 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2745 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2746 return Base::get(context, kernelsTmp);
2747}
2748
2749KernelTableAttr KernelTableAttr::getChecked(
2750 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2751 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2752 // Note that `is_sorted` is always only invoked once even with assertions ON.
2753 assert((!isSorted || llvm::is_sorted(kernels)) &&
2754 "expected a sorted kernel array");
2755 // Immediately return the attribute if the array is sorted.
2756 if (isSorted || llvm::is_sorted(kernels))
2757 return Base::getChecked(emitError, context, kernels);
2758 // Sort the array.
2759 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2760 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2761 return Base::getChecked(emitError, context, kernelsTmp);
2762}
2763
2764LogicalResult
2765KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2766 ArrayRef<KernelMetadataAttr> kernels) {
2767 if (kernels.size() < 2)
2768 return success();
2769 // Check that the kernels are uniquely named.
2770 if (std::adjacent_find(kernels.begin(), kernels.end(),
2771 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2772 return l.getName() == r.getName();
2773 }) != kernels.end()) {
2774 return emitError() << "expected all kernels to be uniquely named";
2775 }
2776 return success();
2777}
2778
2779KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2780 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2781 return found ? *iterator : KernelMetadataAttr();
2782}
2783
2784KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2785 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2786 return found ? *iterator : KernelMetadataAttr();
2787}
2788
2789//===----------------------------------------------------------------------===//
2790// GPU target options
2791//===----------------------------------------------------------------------===//
2792
2807
2825
/// Returns the type ID stored in `typeID`.
2826TypeID TargetOptions::getTypeID() const { return typeID; }
2827
/// Returns the toolkit path as a non-owning view of the `toolkitPath` member.
2828StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2829
2833
/// Returns the command line options as a non-owning view of the `cmdOptions`
/// member.
2834StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2835
/// Returns the ELF section as a non-owning view of the `elfSection` member.
2836StringRef TargetOptions::getELFSection() const { return elfSection; }
2837
2841
2842function_ref<void(llvm::Module &)>
2846
2847function_ref<void(llvm::Module &)>
2851
2852function_ref<void(llvm::Module &)>
2856
2858 return isaCallback;
2859}
2860
/// Returns the compilation target format stored in `compilationTarget`.
2861CompilationTarget TargetOptions::getCompilationTarget() const {
2862 return compilationTarget;
2863}
2864
2866 return CompilationTarget::Fatbin;
2867}
2868
2869std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2871 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2872 llvm::StringSaver stringSaver(options.first);
2873 StringRef opts = cmdOptions;
2874 // For a correct tokenization of the command line options `opts` must be
2875 // unquoted, otherwise the tokenization function returns a single string: the
2876 // unquoted `cmdOptions` -which is not the desired behavior.
2877 // Remove any quotes if they are at the beginning and end of the string:
2878 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2879 opts.consume_front("\""), opts.consume_back("\"");
2880 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2881 opts.consume_front("'"), opts.consume_back("'");
2882#ifdef _WIN32
2883 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2884 /*MarkEOLs=*/false);
2885#else
2886 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2887 /*MarkEOLs=*/false);
2888#endif // _WIN32
2889 return options;
2890}
2891
2892std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2896
2897std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2899 size_t startPos = cmdOptions.find(startsWith);
2900 if (startPos == std::string::npos)
2901 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2902
2903 auto tokenized =
2904 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2905 cmdOptions.resize(startPos);
2906 return tokenized;
2907}
2908
2910
2911#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2912#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2913
2914#define GET_ATTRDEF_CLASSES
2915#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2916
2917#define GET_OP_CLASSES
2918#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2919
2920#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:607
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:139
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator points to the found element or whether it indicates the lower bound.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shaped type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on invalid parameters.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.
Definition GPUDialect.h:39