//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// The logical id for e.g. physical id 5 (logical 2) is obtained with the
/// filter (1 << 5) - 1:
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 0 1 1 1 1 1
/// Intersection   : 0 0 0 0 1 0 1 0 0
/// PopCnt         : 2
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}
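
// Illustrative sketch (not part of the upstream file): for a hypothetical
// mask value 0b110100 (52, i.e. the example above) and an i64 physical id
// %pid, the builder calls above emit roughly:
//   %mask     = arith.constant 52 : i64
//   %one      = arith.constant 1 : i64
//   %shifted  = arith.shli %one, %pid : i64
//   %filter   = arith.subi %shifted, %one : i64
//   %filtered = arith.andi %mask, %filter : i64
//   %logical  = math.ctpop %filtered : i64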

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// The predicate for e.g. physical id 5 (logical 2) is obtained with the
/// filter (1 << 5):
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 1 0 0 0 0 0
/// Intersection   : 0 0 0 1 0 0 0 0 0
/// Cmp            : 1
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
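
// Sketch mirroring the example above: for mask 52 (0b110100) and physical id
// %pid, the result computes `(mask & (1 << %pid)) != 0`, a single i1
// predicate that is true exactly for the active physical ids 2, 4 and 5.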

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}

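// For illustration (not upstream text): !gpu.mma_matrix<16x16xf16, "AOp">
// satisfies these checks, while a 1-D shape or an unknown operand string such
// as "DOp" is rejected with the diagnostics above.
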
//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here; it should correspond to what the parser
// accepts.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
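
// Round-trip examples handled by parseType/printType above (illustrative):
//   !gpu.async.token
//   !gpu.mma_matrix<16x16xf16, "AOp">
//   !gpu.sparse.spmat_handle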

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
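
// For instance (sketch): an async op with two dependencies prints as
// `async [%t0, %t1]`; a synchronous op with dependencies prints just
// `[%t0, %t1]`.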

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
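
// Printed form (illustrative): ` workgroup(%buf : memref<32xf32, 3>)`,
// followed by an optional attribute dictionary per attribution when one is
// supplied.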

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

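// For example, `minnumf` over an i32 value or `xor` over an f32 value fails
// this check; operations absent from both lists, such as `add` and `mul`,
// are accepted for either class of types.
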
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in the gpu.launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

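// Usage sketch (hypothetical values): given a token %t produced by
// `%t = gpu.wait async`, calling addAsyncDependency(op, t) prepends %t to
// op's async dependency list and bumps the first operandSegmentSizes entry.
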
//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

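// Example output (sketch): `(%bx, %by, %bz) in (%gx = %c8, %gy = %c4, %gz =
// %c1)`, where %bx..%bz are region id arguments and %c8..%c1 are the size
// operands passed to the launch.
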
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
                    StringRef keyword) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();

  if (args.size() != 3) {
    return parser.emitError(parser.getNameLoc())
           << keyword << " expects 3 arguments, but got " << args.size();
  }
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               (`dynamic_shared_memory_size` ssa-use)?
///               (`module(` symbol-ref-id `)`)?
///               (`function(` symbol-ref-id `)`)?
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0) {
    if (!asyncTokenType)
      return parser.emitError(
          parser.getNameLoc(),
          "gpu.launch requires 'async' keyword to return a value");
    result.types.push_back(asyncTokenType);
  }

  bool hasCluster = false;
  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three operand segments assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(
            parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
            regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
                          LaunchOp::getBlocksKeyword()) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
                          LaunchOp::getThreadsKeyword()) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

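// A minimal form the parser above accepts (sketch; all names hypothetical):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c1, %sz = %c1) {
//     gpu.terminator
//   }
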
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

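// Illustration (sketch): when a launch has `%sx = %c1` for the x block size,
// every use of the x thread id in the body is replaced by one
// `arith.constant 0 : index` created at the top of the entry block.
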
void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects the types of the cluster dimensions to be the same";
  }

  return success();
}

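// For reference (illustrative sketch): a well-formed use looks like
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1)
//       args(%arg0 : memref<?xf32>)
// inside a module that carries the gpu.container_module unit attribute.
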
static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
  if (!nextOp)
    return failure();

  std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
  std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();

  if (thisMemfence) {
    rewriter.modifyOpInPlace(op, [&]() {
      if (!nextMemfence) {
        op.removeAddressSpacesAttr();
        return;
      }
      // Fast path - merge where the two barriers fence the same spaces.
      if (*thisMemfence == *nextMemfence) {
        return;
      }

      llvm::SmallSetVector<Attribute, 4> mergedSpaces;
      for (Attribute attr : *thisMemfence)
        mergedSpaces.insert(attr);
      for (Attribute attr : *nextMemfence)
        mergedSpaces.insert(attr);
      op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
    });
  }

  rewriter.eraseOp(nextOp);
  return success();
}

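// E.g. (sketch): two adjacent barriers fencing #gpu.address_space<workgroup>
// and #gpu.address_space<private> respectively are folded into a single
// barrier fencing both address spaces.
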
void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

void BarrierOp::build(mlir::OpBuilder &odsBuilder,
                      mlir::OperationState &odsState,
                      std::optional<AddressSpace> addressSpace) {
  ArrayAttr addressSpacesAttr;
  if (addressSpace)
    addressSpacesAttr = odsBuilder.getArrayAttr(
        AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
  build(odsBuilder, odsState, addressSpacesAttr);
}

/// Builds a barrier that causes memory operations affecting `memrefToFence` to
/// be completed after the barrier is concluded. Currently, this means setting
/// the fenced address spaces to those of the given memref if it is a gpu
/// address space.
void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
                      Value memrefToFence) {
  std::optional<AddressSpace> addrSpaceToFence;
  if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
    if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
            memrefType.getMemorySpace()))
      addrSpaceToFence = addrSpaceAttr.getValue();
  return build(builder, odsState, addrSpaceToFence);
}

1545//===----------------------------------------------------------------------===//
1546// GPUFuncOp
1547//===----------------------------------------------------------------------===//
1548
1549/// Adds a new block argument that corresponds to buffers located in
1550/// workgroup memory.
1551BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1552 auto attrName = getNumWorkgroupAttributionsAttrName();
1553 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1554 (*this)->setAttr(attrName,
1555 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1556 return getBody().insertArgument(
1557 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1558}
1559
1560/// Adds a new block argument that corresponds to buffers located in
1561/// private memory.
1562BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1563 // Buffers on the private memory always come after buffers on the workgroup
1564 // memory.
1565 return getBody().addArgument(type, loc);
1566}
1567
1568void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1569 StringRef name, FunctionType type,
1570 TypeRange workgroupAttributions,
1571 TypeRange privateAttributions,
1572 ArrayRef<NamedAttribute> attrs) {
1573 OpBuilder::InsertionGuard g(builder);
1574
1575  result.addAttribute(SymbolTable::getSymbolAttrName(),
1576                      builder.getStringAttr(name));
1577 result.addAttribute(getFunctionTypeAttrName(result.name),
1578 TypeAttr::get(type));
1579 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1580 builder.getI64IntegerAttr(workgroupAttributions.size()));
1581 result.addAttributes(attrs);
1582 Region *body = result.addRegion();
1583 Block *entryBlock = builder.createBlock(body);
1584
1585 // TODO: Allow passing in proper locations here.
1586 for (Type argTy : type.getInputs())
1587 entryBlock->addArgument(argTy, result.location);
1588 for (Type argTy : workgroupAttributions)
1589 entryBlock->addArgument(argTy, result.location);
1590 for (Type argTy : privateAttributions)
1591 entryBlock->addArgument(argTy, result.location);
1592}
1593
1594/// Parses a GPU function memory attribution.
1595///
1596/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1597/// (`private` `(` ssa-id-and-type-list `)`)?
1598///
1599/// Note that this function parses only one of the two similar parts, with the
1600/// keyword provided as argument.
1601static ParseResult
1602parseAttributions(OpAsmParser &parser, StringRef keyword,
1603                  SmallVectorImpl<OpAsmParser::Argument> &args,
1604                  Attribute &attributionAttrs) {
1605  // If we could not parse the keyword, just assume an empty list and succeed.
1606 if (failed(parser.parseOptionalKeyword(keyword)))
1607 return success();
1608
1609 size_t existingArgs = args.size();
1610 ParseResult result =
1611      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
1612                               /*allowType=*/true, /*allowAttrs=*/true);
1613 if (failed(result))
1614 return result;
1615
1616 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1617 [](const OpAsmParser::Argument &arg) -> bool {
1618 return arg.attrs && !arg.attrs.empty();
1619 });
1620 if (!hadAttrs) {
1621 attributionAttrs = nullptr;
1622 return result;
1623 }
1624
1625 Builder &builder = parser.getBuilder();
1626 SmallVector<Attribute> attributionAttrsVec;
1627 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1628 if (!argument.attrs)
1629 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1630 else
1631 attributionAttrsVec.push_back(argument.attrs);
1632 }
1633 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1634 return result;
1635}
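// For reference (editorial; the exemplary types are assumed), the attribution
// syntax parsed above looks like
//   workgroup(%sum : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%tmp : memref<4xf32, #gpu.address_space<private>>)
// where each argument may additionally carry an attribute dictionary.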
1636
1637/// Parses a GPU function.
1638///
1639/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1640/// (`->` function-result-list)? memory-attribution `kernel`?
1641/// function-attributes? region
1642ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1643 SmallVector<OpAsmParser::Argument> entryArgs;
1644 SmallVector<DictionaryAttr> resultAttrs;
1645 SmallVector<Type> resultTypes;
1646 bool isVariadic;
1647
1648 // Parse the function name.
1649 StringAttr nameAttr;
1650  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1651                             result.attributes))
1652 return failure();
1653
1654 auto signatureLocation = parser.getCurrentLocation();
1655  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1656          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1657 resultAttrs)))
1658 return failure();
1659
1660 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1661 return parser.emitError(signatureLocation)
1662 << "gpu.func requires named arguments";
1663
1664 // Construct the function type. More types will be added to the region, but
1665 // not to the function type.
1666 Builder &builder = parser.getBuilder();
1667
1668 SmallVector<Type> argTypes;
1669 for (auto &arg : entryArgs)
1670 argTypes.push_back(arg.type);
1671 auto type = builder.getFunctionType(argTypes, resultTypes);
1672 result.addAttribute(getFunctionTypeAttrName(result.name),
1673 TypeAttr::get(type));
1674
1675  call_interface_impl::addArgAndResultAttrs(
1676      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1677 getResAttrsAttrName(result.name));
1678
1679 Attribute workgroupAttributionAttrs;
1680 // Parse workgroup memory attributions.
1681 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1682 entryArgs, workgroupAttributionAttrs)))
1683 return failure();
1684
1685 // Store the number of operands we just parsed as the number of workgroup
1686 // memory attributions.
1687 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1688 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1689 builder.getI64IntegerAttr(numWorkgroupAttrs));
1690 if (workgroupAttributionAttrs)
1691 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1692 workgroupAttributionAttrs);
1693
1694 Attribute privateAttributionAttrs;
1695 // Parse private memory attributions.
1696 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1697 entryArgs, privateAttributionAttrs)))
1698 return failure();
1699 if (privateAttributionAttrs)
1700 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1701 privateAttributionAttrs);
1702
1703 // Parse the kernel attribute if present.
1704 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1705 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1706 builder.getUnitAttr());
1707
1708 // Parse attributes.
1709 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1710 return failure();
1711
1712 // Parse the region. If no argument names were provided, take all names
1713 // (including those of attributions) from the entry block.
1714 auto *body = result.addRegion();
1715 return parser.parseRegion(*body, entryArgs);
1716}
1717
1718void GPUFuncOp::print(OpAsmPrinter &p) {
1719 p << ' ';
1720 p.printSymbolName(getName());
1721
1722 FunctionType type = getFunctionType();
1723 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1724 /*isVariadic=*/false,
1725 type.getResults());
1726
1727 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1728 getWorkgroupAttribAttrs().value_or(nullptr));
1729 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1730 getPrivateAttribAttrs().value_or(nullptr));
1731 if (isKernel())
1732 p << ' ' << getKernelKeyword();
1733
1734  function_interface_impl::printFunctionAttributes(
1735      p, *this,
1736 {getNumWorkgroupAttributionsAttrName(),
1737 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1738 getArgAttrsAttrName(), getResAttrsAttrName(),
1739 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1740 p << ' ';
1741 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1742}
1743
1744static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1745 StringAttr attrName) {
1746 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1747 if (!allAttrs || index >= allAttrs.size())
1748 return DictionaryAttr();
1749 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1750}
1751
1752DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1753 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1754}
1755
1756DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1757 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1758}
1759
1760static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1761 DictionaryAttr value, StringAttr attrName) {
1762 MLIRContext *ctx = op.getContext();
1763 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1764 SmallVector<Attribute> elements;
1765 if (allAttrs)
1766 elements.append(allAttrs.begin(), allAttrs.end());
1767 while (elements.size() <= index)
1768 elements.push_back(DictionaryAttr::get(ctx));
1769 if (!value)
1770 elements[index] = DictionaryAttr::get(ctx);
1771 else
1772 elements[index] = value;
1773 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1774 op->setAttr(attrName, newValue);
1775}
1776
1777void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1778 DictionaryAttr value) {
1779 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1780}
1781
1782void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1783 DictionaryAttr value) {
1784 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1785}
1786
1787static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1788 StringAttr name, StringAttr attrsName) {
1789 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1790 if (!dict)
1791 return Attribute();
1792 return dict.get(name);
1793}
1794
1795Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1796 StringAttr name) {
1797 assert(index < getNumWorkgroupAttributions() &&
1798 "index must map to a workgroup attribution");
1799 return getAttributionAttr(*this, index, name,
1800 getWorkgroupAttribAttrsAttrName());
1801}
1802
1803Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1804 StringAttr name) {
1805 assert(index < getNumPrivateAttributions() &&
1806 "index must map to a private attribution");
1807 return getAttributionAttr(*this, index, name,
1808 getPrivateAttribAttrsAttrName());
1809}
1810
1811static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1812 Attribute value, StringAttr attrsName) {
1813 MLIRContext *ctx = op.getContext();
1814  SmallVector<NamedAttribute> elems;
1815  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1816 if (oldDict)
1817 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1818
1819 bool found = false;
1820 bool mustSort = true;
1821 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1822 if (elems[i].getName() == name) {
1823 found = true;
1824 if (!value) {
1825 std::swap(elems[i], elems[elems.size() - 1]);
1826 elems.pop_back();
1827 } else {
1828 mustSort = false;
1829 elems[i] = NamedAttribute(elems[i].getName(), value);
1830 }
1831 break;
1832 }
1833 }
1834 if (!found) {
1835 if (!value)
1836 return;
1837 elems.emplace_back(name, value);
1838 }
1839 if (mustSort) {
1840 DictionaryAttr::sortInPlace(elems);
1841 }
1842 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1843 setAttributionAttrs(op, index, newDict, attrsName);
1844}
1845
1846void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1847 Attribute value) {
1848 assert(index < getNumWorkgroupAttributions() &&
1849 "index must map to a workgroup attribution");
1850 setAttributionAttr(*this, index, name, value,
1851 getWorkgroupAttribAttrsAttrName());
1852}
1853
1854void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1855 Attribute value) {
1856 assert(index < getNumPrivateAttributions() &&
1857 "index must map to a private attribution");
1858 setAttributionAttr(*this, index, name, value,
1859 getPrivateAttribAttrsAttrName());
1860}
1861
1862LogicalResult GPUFuncOp::verifyType() {
1863 if (isKernel() && getFunctionType().getNumResults() != 0)
1864 return emitOpError() << "expected void return type for kernel function";
1865
1866 return success();
1867}
1868
1869/// Verifies the body of the function.
1870LogicalResult GPUFuncOp::verifyBody() {
1871 if (empty())
1872 return emitOpError() << "expected body with at least one block";
1873 unsigned numFuncArguments = getNumArguments();
1874 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1875 unsigned numBlockArguments = front().getNumArguments();
1876 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1877 return emitOpError() << "expected at least "
1878 << numFuncArguments + numWorkgroupAttributions
1879 << " arguments to body region";
1880
1881 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1882 for (unsigned i = 0; i < numFuncArguments; ++i) {
1883 Type blockArgType = front().getArgument(i).getType();
1884 if (funcArgTypes[i] != blockArgType)
1885 return emitOpError() << "expected body region argument #" << i
1886 << " to be of type " << funcArgTypes[i] << ", got "
1887 << blockArgType;
1888 }
1889
1890 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1891 GPUDialect::getWorkgroupAddressSpace())) ||
1892 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1893 GPUDialect::getPrivateAddressSpace())))
1894 return failure();
1895
1896 return success();
1897}
1898
1899//===----------------------------------------------------------------------===//
1900// ReturnOp
1901//===----------------------------------------------------------------------===//
1902
1903LogicalResult gpu::ReturnOp::verify() {
1904 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1905
1906 FunctionType funType = function.getFunctionType();
1907
1908 if (funType.getNumResults() != getOperands().size())
1909 return emitOpError()
1910 .append("expected ", funType.getNumResults(), " result operands")
1911 .attachNote(function.getLoc())
1912 .append("return type declared here");
1913
1914 for (const auto &pair : llvm::enumerate(
1915 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1916 auto [type, operand] = pair.value();
1917 if (type != operand.getType())
1918 return emitOpError() << "unexpected type `" << operand.getType()
1919 << "' for operand #" << pair.index();
1920 }
1921 return success();
1922}
1923
1924//===----------------------------------------------------------------------===//
1925// GPUModuleOp
1926//===----------------------------------------------------------------------===//
1927
1928void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1929 StringRef name, ArrayAttr targets,
1930 Attribute offloadingHandler) {
1931 result.addRegion()->emplaceBlock();
1932 Properties &props = result.getOrAddProperties<Properties>();
1933 if (targets)
1934 props.targets = targets;
1935 props.setSymName(builder.getStringAttr(name));
1936 props.offloadingHandler = offloadingHandler;
1937}
1938
1939void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1940 StringRef name, ArrayRef<Attribute> targets,
1941 Attribute offloadingHandler) {
1942 build(builder, result, name,
1943 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1944 offloadingHandler);
1945}
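// Editorial example (the target attribute is illustrative): a module built
// with a non-empty `targets` array round-trips as
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//     ...
//   }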
1946
1947bool GPUModuleOp::hasTarget(Attribute target) {
1948 if (ArrayAttr targets = getTargetsAttr())
1949 return llvm::count(targets.getValue(), target);
1950 return false;
1951}
1952
1953void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1954 ArrayAttr &targetsAttr = getProperties().targets;
1955 SmallVector<Attribute> targetsVector(targets);
1956 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1957}
1958
1959LogicalResult GPUModuleOp::verify() {
1960 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1961
1962 if (!targets)
1963 return success();
1964
1965 for (auto target : targets) {
1966 if (auto verifyTargetAttr =
1967 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1968 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1969 return failure();
1970 }
1971 }
1972 return success();
1973}
1974
1975//===----------------------------------------------------------------------===//
1976// GPUBinaryOp
1977//===----------------------------------------------------------------------===//
1978void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1979 Attribute offloadingHandler, ArrayAttr objects) {
1980 auto &properties = result.getOrAddProperties<Properties>();
1981 result.attributes.push_back(builder.getNamedAttr(
1982      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1983  properties.objects = objects;
1984 if (offloadingHandler)
1985 properties.offloadingHandler = offloadingHandler;
1986 else
1987 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1988}
1989
1990void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1991 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1992 build(builder, result, name, offloadingHandler,
1993 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1994}
1995
1996static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1997 Attribute &offloadingHandler) {
1998 if (succeeded(parser.parseOptionalLess())) {
1999 if (parser.parseAttribute(offloadingHandler))
2000 return failure();
2001 if (parser.parseGreater())
2002 return failure();
2003 }
2004 if (!offloadingHandler)
2005 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2006 return success();
2007}
2008
2009static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
2010                                   Attribute offloadingHandler) {
2011 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2012 printer << '<' << offloadingHandler << '>';
2013}
2014
2015//===----------------------------------------------------------------------===//
2016// GPUMemcpyOp
2017//===----------------------------------------------------------------------===//
2018
2019LogicalResult MemcpyOp::verify() {
2020 auto srcType = getSrc().getType();
2021 auto dstType = getDst().getType();
2022
2023 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2024 return emitOpError("arguments have incompatible element type");
2025
2026 if (failed(verifyCompatibleShape(srcType, dstType)))
2027 return emitOpError("arguments have incompatible shape");
2028
2029 return success();
2030}
2031
2032namespace {
2033
2034/// Erases a common case of copy ops where the destination value is used only
2035/// by the copy op itself and alloc/dealloc ops.
2036struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2037 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2038
2039 LogicalResult matchAndRewrite(MemcpyOp op,
2040 PatternRewriter &rewriter) const override {
2041 Value dest = op.getDst();
2042 Operation *destDefOp = dest.getDefiningOp();
2043    // `dest` must be defined by an op having the Allocate memory effect in
2044    // order to perform the folding.
2045    if (!destDefOp ||
2046        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
2047      return failure();
2048 // We can erase `op` iff `dest` has no other use apart from its
2049 // use by `op` and dealloc ops.
2050 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2051 return user != op &&
2052 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2053 }))
2054 return failure();
2055 // We can perform the folding if and only if op has a single async
2056 // dependency and produces an async token as result, or if it does not have
2057 // any async dependency and does not produce any async token result.
2058 if (op.getAsyncDependencies().size() > 1 ||
2059 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2060 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2061 return failure();
2062 rewriter.replaceOp(op, op.getAsyncDependencies());
2063 return success();
2064 }
2065};
2066
2067} // end anonymous namespace
2068
2069void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2070 MLIRContext *context) {
2071 results.add<EraseTrivialCopyOp>(context);
2072}
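// Editorial illustration of EraseTrivialCopyOp: in
//   %dst = gpu.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>
// %dst is used only by the allocation, the copy, and the deallocation, so the
// gpu.memcpy can be erased; the now-dead alloc/dealloc pair is left to other
// patterns.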
2073
2074//===----------------------------------------------------------------------===//
2075// GPU_SubgroupMmaLoadMatrixOp
2076//===----------------------------------------------------------------------===//
2077
2078LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2079 auto srcType = getSrcMemref().getType();
2080 auto resType = getRes().getType();
2081 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2082 auto operand = resMatrixType.getOperand();
2083 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2084
2085 if (!srcMemrefType.isLastDimUnitStride())
2086 return emitError(
2087 "expected source memref most minor dim must have unit stride");
2088
2089 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2090 return emitError("only AOp, BOp and COp can be loaded");
2091
2092 return success();
2093}
2094
2095//===----------------------------------------------------------------------===//
2096// GPU_SubgroupMmaStoreMatrixOp
2097//===----------------------------------------------------------------------===//
2098
2099LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2100 auto srcType = getSrc().getType();
2101 auto dstType = getDstMemref().getType();
2102 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2103 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2104
2105 if (!dstMemrefType.isLastDimUnitStride())
2106 return emitError(
2107 "expected destination memref most minor dim must have unit stride");
2108
2109 if (srcMatrixType.getOperand() != "COp")
2110 return emitError(
2111 "expected the operand matrix being stored to have 'COp' operand type");
2112
2113 return success();
2114}
2115
2116//===----------------------------------------------------------------------===//
2117// GPU_SubgroupMmaComputeOp
2118//===----------------------------------------------------------------------===//
2119
2120LogicalResult SubgroupMmaComputeOp::verify() {
2121 enum OperandMap { A, B, C };
2122 SmallVector<MMAMatrixType, 3> opTypes;
2123 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2124 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2125 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2126
2127 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2128 opTypes[C].getOperand() != "COp")
2129 return emitError("operands must be in the order AOp, BOp, COp");
2130
2131 ArrayRef<int64_t> aShape, bShape, cShape;
2132 aShape = opTypes[A].getShape();
2133 bShape = opTypes[B].getShape();
2134 cShape = opTypes[C].getShape();
2135
2136 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2137 bShape[1] != cShape[1])
2138 return emitError("operand shapes do not satisfy matmul constraints");
2139
2140 return success();
2141}
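// Editorial example of the constraint above: AOp of shape 16x8, BOp of shape
// 8x32, and COp of shape 16x32 verify, since aShape[1] == bShape[0] (8),
// aShape[0] == cShape[0] (16), and bShape[1] == cShape[1] (32), i.e. a
// 16x8 * 8x32 -> 16x32 matmul.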
2142
2143LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2144 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2145 return memref::foldMemRefCast(*this);
2146}
2147
2148LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2149 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2150 return memref::foldMemRefCast(*this);
2151}
2152
2153//===----------------------------------------------------------------------===//
2154// GPU_WaitOp
2155//===----------------------------------------------------------------------===//
2156
2157namespace {
2158
2159/// Remove a gpu.wait op's use of a gpu.wait op def that has no async dependencies.
2160/// %t = gpu.wait async [] // No async dependencies.
2161/// ... gpu.wait ... [%t, ...] // %t can be removed.
2162struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2163public:
2164  using OpRewritePattern<WaitOp>::OpRewritePattern;
2165
2166 LogicalResult matchAndRewrite(WaitOp op,
2167 PatternRewriter &rewriter) const final {
2168 auto predicate = [](Value value) {
2169 auto waitOp = value.getDefiningOp<WaitOp>();
2170 return waitOp && waitOp->getNumOperands() == 0;
2171 };
2172 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2173 return failure();
2174 SmallVector<Value> validOperands;
2175 for (Value operand : op->getOperands()) {
2176 if (predicate(operand))
2177 continue;
2178 validOperands.push_back(operand);
2179 }
2180 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2181 return success();
2182 }
2183};
2184
2185/// Simplify trivial gpu.wait ops for the following patterns.
2186/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2187/// dependencies).
2188/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2189/// %t0.
2190/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2191/// dependencies nor return any token.
2192struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2193public:
2194  using OpRewritePattern<WaitOp>::OpRewritePattern;
2195
2196 LogicalResult matchAndRewrite(WaitOp op,
2197 PatternRewriter &rewriter) const final {
2198 // Erase gpu.wait ops that neither have any async dependencies nor return
2199 // any async token.
2200 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2201 rewriter.eraseOp(op);
2202 return success();
2203 }
2204 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2205 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2206 op.getAsyncToken()) {
2207 rewriter.replaceOp(op, op.getAsyncDependencies());
2208 return success();
2209 }
2210 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2211 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2212 rewriter.eraseOp(op);
2213 return success();
2214 }
2215 return failure();
2216 }
2217};
2218
2219} // end anonymous namespace
2220
2221void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2222 MLIRContext *context) {
2223 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2224}
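// Editorial illustration of the two patterns combined:
//   %t0 = gpu.wait async []    // Dependency-free def: dropped from user lists.
//   %t1 = gpu.wait async [%t0] // Single dependency: uses of %t1 become %t0.
//   gpu.wait []                // No dependencies, no token: erased outright.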
2225
2226//===----------------------------------------------------------------------===//
2227// GPU_AllocOp
2228//===----------------------------------------------------------------------===//
2229
2230LogicalResult AllocOp::verify() {
2231 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2232
2233 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2234 getDynamicSizes())))
2235 return failure();
2236
2237 unsigned numSymbols = 0;
2238 if (!memRefType.getLayout().isIdentity())
2239 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2240 if (getSymbolOperands().size() != numSymbols) {
2241 return emitOpError(
2242 "symbol operand count does not equal memref symbol count");
2243 }
2244
2245 return success();
2246}
2247
2248namespace {
2249
2250/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2251/// `memref::AllocOp`.
2252struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2253 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2254
2255 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2256 PatternRewriter &rewriter) const override {
2257 std::optional<int64_t> index = dimOp.getConstantIndex();
2258 if (!index)
2259 return failure();
2260
2261 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2262 if (!memrefType || index.value() >= memrefType.getRank() ||
2263 !memrefType.isDynamicDim(index.value()))
2264 return failure();
2265
2266 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2267 if (!alloc)
2268 return failure();
2269
2270 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2271 memrefType.getDynamicDimIndex(index.value()));
2272 rewriter.replaceOp(dimOp, substituteOp);
2273 return success();
2274 }
2275};
2276
2277} // namespace
2278
2279void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2280 MLIRContext *context) {
2281 results.add<SimplifyDimOfAllocOp>(context);
2282}
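// Editorial illustration of SimplifyDimOfAllocOp: given
//   %m = gpu.alloc(%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
// the memref.dim op is replaced by %size, the dynamic size that was passed to
// gpu.alloc.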
2283
2284//===----------------------------------------------------------------------===//
2285// GPU object attribute
2286//===----------------------------------------------------------------------===//
2287
2288LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2289 Attribute target, CompilationTarget format,
2290 StringAttr object, DictionaryAttr properties,
2291 KernelTableAttr kernels) {
2292 if (!target)
2293 return emitError() << "the target attribute cannot be null";
2294 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2295 return success();
2296 return emitError() << "the target attribute must implement or promise the "
2297 "`gpu::TargetAttrInterface`";
2298}
2299
2300namespace {
2301ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2302 StringAttr &object) {
2303 std::optional<CompilationTarget> formatResult;
2304 StringRef enumKeyword;
2305 auto loc = odsParser.getCurrentLocation();
2306 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2307 formatResult = CompilationTarget::Fatbin;
2308 if (!formatResult &&
2309 (formatResult =
2310 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2311 odsParser.parseEqual())
2312 return odsParser.emitError(loc, "expected an equal sign");
2313 if (!formatResult)
2314 return odsParser.emitError(loc, "expected keyword for GPU object format");
2315 FailureOr<StringAttr> objectResult =
2316 FieldParser<StringAttr>::parse(odsParser);
2317 if (failed(objectResult))
2318 return odsParser.emitError(odsParser.getCurrentLocation(),
2319 "failed to parse GPU_ObjectAttr parameter "
2320 "'object' which is to be a `StringAttr`");
2321 format = *formatResult;
2322 object = *objectResult;
2323 return success();
2324}
2325
2326void printObject(AsmPrinter &odsParser, CompilationTarget format,
2327 StringAttr object) {
2328 if (format != CompilationTarget::Fatbin)
2329 odsParser << stringifyEnum(format) << " = ";
2330 odsParser << object;
2331}
2332} // namespace
2333
2334//===----------------------------------------------------------------------===//
2335// GPU select object attribute
2336//===----------------------------------------------------------------------===//
2337
2338LogicalResult
2339gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2340 Attribute target) {
2341 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2342 if (target) {
2343 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2344 if (intAttr.getInt() < 0) {
2345 return emitError() << "the object index must be positive";
2346 }
2347 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2348 return emitError()
2349 << "the target attribute must be a GPU Target attribute";
2350 }
2351 }
2352 return success();
2353}
2354
2355//===----------------------------------------------------------------------===//
2356// DynamicSharedMemoryOp
2357//===----------------------------------------------------------------------===//
2358
2359LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2360 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2361 return emitOpError() << "must be inside an op with symbol table";
2362
2363 MemRefType memrefType = getResultMemref().getType();
2364 // Check address space
2365 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2366 return emitOpError() << "address space must be "
2367 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2368 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2369 }
2370 if (memrefType.hasStaticShape()) {
2371 return emitOpError() << "result memref type must be memref<?xi8, "
2372 "#gpu.address_space<workgroup>>";
2373 }
2374 return success();
2375}
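// Editorial example of a well-formed instance, inside an op with a symbol
// table such as gpu.module:
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>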
2376
2377//===----------------------------------------------------------------------===//
2378// GPU WarpExecuteOnLane0Op
2379//===----------------------------------------------------------------------===//
2380
2381void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2382 p << "(" << getLaneid() << ")";
2383
2384 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2385 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2386 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2387
2388 if (!getArgs().empty())
2389 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2390 if (!getResults().empty())
2391 p << " -> (" << getResults().getTypes() << ')';
2392 p << " ";
2393 p.printRegion(getRegion(),
2394 /*printEntryBlockArgs=*/true,
2395 /*printBlockTerminators=*/!getResults().empty());
2396 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2397}
2398
2399ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2400 OperationState &result) {
2401 // Create the region.
2402 result.regions.reserve(1);
2403 Region *warpRegion = result.addRegion();
2404
2405 auto &builder = parser.getBuilder();
2406 OpAsmParser::UnresolvedOperand laneId;
2407
2408 // Parse predicate operand.
2409 if (parser.parseLParen() ||
2410 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2411 parser.parseRParen())
2412 return failure();
2413
2414 int64_t warpSize;
2415 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2416 parser.parseRSquare())
2417 return failure();
2418 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2419 builder.getContext())),
2420 builder.getI64IntegerAttr(warpSize));
2421
2422 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2423 return failure();
2424
2425 llvm::SMLoc inputsOperandsLoc;
2426 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2427 SmallVector<Type> inputTypes;
2428 if (succeeded(parser.parseOptionalKeyword("args"))) {
2429 if (parser.parseLParen())
2430 return failure();
2431
2432 inputsOperandsLoc = parser.getCurrentLocation();
2433 if (parser.parseOperandList(inputsOperands) ||
2434 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2435 return failure();
2436 }
2437 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2438 result.operands))
2439 return failure();
2440
2441 // Parse optional results type list.
2442 if (parser.parseOptionalArrowTypeList(result.types))
2443 return failure();
2444 // Parse the region.
2445 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2446 /*argTypes=*/{}))
2447 return failure();
2448 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2449
2450 // Parse the optional attribute list.
2451 if (parser.parseOptionalAttrDict(result.attributes))
2452 return failure();
2453 return success();
2454}
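// Editorial example of the textual form handled above (types chosen to satisfy
// the distribution rules verified further below, with warp size 32):
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg0: vector<128xf32>):
//     ...
//     gpu.yield %w : vector<32xf32>
//   }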
2455
2456void WarpExecuteOnLane0Op::getSuccessorRegions(
2457 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2458 if (!point.isParent()) {
2459 regions.push_back(RegionSuccessor::parent());
2460 return;
2461 }
2462
2463 // The warp region is always executed
2464 regions.push_back(RegionSuccessor(&getWarpRegion()));
2465}
2466
2467ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2468 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2469}
2470void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2471 TypeRange resultTypes, Value laneId,
2472 int64_t warpSize) {
2473 build(builder, result, resultTypes, laneId, warpSize,
2474 /*operands=*/{}, /*argTypes=*/{});
2475}
2476
2477void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2478 TypeRange resultTypes, Value laneId,
2479 int64_t warpSize, ValueRange args,
2480 TypeRange blockArgTypes) {
2481 result.addOperands(laneId);
2482 result.addAttribute(getAttributeNames()[0],
2483 builder.getI64IntegerAttr(warpSize));
2484 result.addTypes(resultTypes);
2485 result.addOperands(args);
2486 assert(args.size() == blockArgTypes.size());
2487 OpBuilder::InsertionGuard guard(builder);
2488 Region *warpRegion = result.addRegion();
2489 Block *block = builder.createBlock(warpRegion);
2490 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2491 block->addArgument(type, arg.getLoc());
2492}
2493
2494/// Helper to check that the distributed vector type is consistent with the
2495/// expanded type and the distributed size.
2496static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2497 int64_t warpSize, Operation *op) {
2498  // If the types match, there is no distribution.
2499 if (expanded == distributed)
2500 return success();
2501 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2502 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2503 if (!expandedVecType || !distributedVecType)
2504 return op->emitOpError("expected vector type for distributed operands.");
2505 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2506 expandedVecType.getElementType() != distributedVecType.getElementType())
2507 return op->emitOpError(
2508 "expected distributed vectors to have same rank and element type.");
2509
2510 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2511 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2512 int64_t eDim = expandedVecType.getDimSize(i);
2513 int64_t dDim = distributedVecType.getDimSize(i);
2514 if (eDim == dDim)
2515 continue;
2516 if (eDim % dDim != 0)
2517 return op->emitOpError()
2518 << "expected expanded vector dimension #" << i << " (" << eDim
2519 << ") to be a multipler of the distributed vector dimension ("
2520 << dDim << ")";
2521 scales[i] = eDim / dDim;
2522 }
2523 if (llvm::product_of(scales) != warpSize)
2524 return op->emitOpError()
2525 << "incompatible distribution dimensions from " << expandedVecType
2526 << " to " << distributedVecType << " with warp size = " << warpSize;
2527
2528 return success();
2529}
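// Editorial worked example: expanded vector<64x4xf32> against distributed
// vector<2x4xf32> gives scales = [32, 1], whose product (32) must equal the
// warp size, so the pair verifies for warpSize = 32; a pair of dimensions such
// as 3 vs. 2 would fail the divisibility check instead.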
2530
2531LogicalResult WarpExecuteOnLane0Op::verify() {
2532 if (getArgs().size() != getWarpRegion().getNumArguments())
2533 return emitOpError(
2534 "expected same number op arguments and block arguments.");
2535 gpu::YieldOp yield = getTerminator();
2536 if (yield.getNumOperands() != getNumResults())
2537 return emitOpError(
2538 "expected same number of yield operands and return values.");
2539 int64_t warpSize = getWarpSize();
2540 for (auto [regionArg, arg] :
2541 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2542 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2543 warpSize, getOperation())))
2544 return failure();
2545 }
2546 for (auto [yieldOperand, result] :
2547 llvm::zip_equal(yield.getOperands(), getResults())) {
2548 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2549 warpSize, getOperation())))
2550 return failure();
2551 }
2552 return success();
2553}
2554bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2555 return succeeded(
2556 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2557}
2558
2559gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2560 return cast<gpu::YieldOp>(getBody()->getTerminator());
2561}
2562
2563//===----------------------------------------------------------------------===//
2564// GPU_SubgroupBroadcastOp
2565//===----------------------------------------------------------------------===//
2566
2567void gpu::SubgroupBroadcastOp::inferResultRanges(
2568 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2569 setResultRange(getResult(), argRanges.front());
2570}
2571
2572Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2573 switch (getBroadcastType()) {
2574 case BroadcastType::first_active_lane:
2575    // Cannot speculate a first_active_lane broadcast, because speculating it
2576    // across control flow can change the set of active lanes.
2577    return Speculation::NotSpeculatable;
2578  case BroadcastType::specific_lane:
2579    // Speculation is safe as long as we are inside structured control flow.
2580    return Speculation::Speculatable;
2581  }
2582 llvm_unreachable("Unknown BroadcastType");
2583}
2584
2585LogicalResult gpu::SubgroupBroadcastOp::verify() {
2586 switch (getBroadcastType()) {
2587 case BroadcastType::first_active_lane:
2588 if (getLane())
2589 return emitOpError()
2590 << "lane can only be specified for `specific_lane` broadcast";
2591 return success();
2592 case BroadcastType::specific_lane:
2593 if (!getLane())
2594 return emitOpError()
2595 << "lane must be specified for `specific_lane` broadcast";
2596 return success();
2597 }
2598 llvm_unreachable("Unknown BroadcastType");
2599}
2600
2601OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2602 // Broadcast result is always uniform.
2603 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2604 return prev.getResult();
2605
2606 return nullptr;
2607}
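// Editorial illustration (the assembly form is assumed): the second broadcast
// below folds to %b0, since broadcast results are already subgroup-uniform:
//   %b0 = gpu.subgroup_broadcast %v, first_active_lane : f32
//   %b1 = gpu.subgroup_broadcast %b0, first_active_lane : f32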
2608
2609//===----------------------------------------------------------------------===//
2610// GPU KernelMetadataAttr
2611//===----------------------------------------------------------------------===//
2612
2613KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2614 DictionaryAttr metadata) {
2615 assert(kernel && "invalid kernel");
2616 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2617 kernel.getAllArgAttrs(), metadata);
2618}
2619
2620KernelMetadataAttr
2621KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2622 FunctionOpInterface kernel,
2623 DictionaryAttr metadata) {
2624 assert(kernel && "invalid kernel");
2625 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2626 kernel.getAllArgAttrs(), metadata);
2627}
2628
2629KernelMetadataAttr
2630KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2631 if (attrs.empty())
2632 return *this;
2633 NamedAttrList attrList;
2634 if (DictionaryAttr dict = getMetadata())
2635 attrList.append(dict);
2636 attrList.append(attrs);
2637 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2638 attrList.getDictionary(getContext()));
2639}
2640
2641LogicalResult
2642KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2643 StringAttr name, Type functionType,
2644 ArrayAttr argAttrs, DictionaryAttr metadata) {
2645 if (name.empty())
2646 return emitError() << "the kernel name can't be empty";
2647 if (argAttrs) {
2648 if (llvm::any_of(argAttrs, [](Attribute attr) {
2649 return !llvm::isa<DictionaryAttr>(attr);
2650 }))
2651 return emitError()
2652 << "all attributes in the array must be a dictionary attribute";
2653 }
2654 return success();
2655}
2656
2657//===----------------------------------------------------------------------===//
2658// GPU KernelTableAttr
2659//===----------------------------------------------------------------------===//
2660
2661KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2662 ArrayRef<KernelMetadataAttr> kernels,
2663 bool isSorted) {
2664  // Note that `is_sorted` is invoked exactly once, even with assertions enabled.
2665 assert((!isSorted || llvm::is_sorted(kernels)) &&
2666 "expected a sorted kernel array");
2667 // Immediately return the attribute if the array is sorted.
2668 if (isSorted || llvm::is_sorted(kernels))
2669 return Base::get(context, kernels);
2670 // Sort the array.
2671 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2672 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2673 return Base::get(context, kernelsTmp);
2674}
2675
2676KernelTableAttr KernelTableAttr::getChecked(
2677 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2678 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2679  // Note that `is_sorted` is invoked exactly once, even with assertions enabled.
2680 assert((!isSorted || llvm::is_sorted(kernels)) &&
2681 "expected a sorted kernel array");
2682 // Immediately return the attribute if the array is sorted.
2683 if (isSorted || llvm::is_sorted(kernels))
2684 return Base::getChecked(emitError, context, kernels);
2685 // Sort the array.
2686 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2687 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2688 return Base::getChecked(emitError, context, kernelsTmp);
2689}
2690
2691LogicalResult
2692KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2693 ArrayRef<KernelMetadataAttr> kernels) {
2694 if (kernels.size() < 2)
2695 return success();
2696 // Check that the kernels are uniquely named.
2697 if (std::adjacent_find(kernels.begin(), kernels.end(),
2698 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2699 return l.getName() == r.getName();
2700 }) != kernels.end()) {
2701 return emitError() << "expected all kernels to be uniquely named";
2702 }
2703 return success();
2704}
2705
2706KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2707 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2708 return found ? *iterator : KernelMetadataAttr();
2709}
2710
2711KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2712 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2713 return found ? *iterator : KernelMetadataAttr();
2714}
2715
2716//===----------------------------------------------------------------------===//
2717// GPU target options
2718//===----------------------------------------------------------------------===//
2719
2734
2752
2753TypeID TargetOptions::getTypeID() const { return typeID; }
2754
2755StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2756
2757ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2758  return librariesToLink;
2759}
2760
2761StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2762
2763StringRef TargetOptions::getELFSection() const { return elfSection; }
2764
2765SymbolTable *TargetOptions::getSymbolTable() const {
2766  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2767}
2768
2769function_ref<void(llvm::Module &)>
2770TargetOptions::getInitialLlvmIRCallback() const {
2771  return initialLlvmIRCallback;
2772}
2773
2774function_ref<void(llvm::Module &)>
2775TargetOptions::getLinkedLlvmIRCallback() const {
2776  return linkedLlvmIRCallback;
2777}
2778
2779function_ref<void(llvm::Module &)>
2780TargetOptions::getOptimizedLlvmIRCallback() const {
2781  return optimizedLlvmIRCallback;
2782}
2783
2784function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2785  return isaCallback;
2786}
2787
2788CompilationTarget TargetOptions::getCompilationTarget() const {
2789 return compilationTarget;
2790}
2791
2792CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2793  return CompilationTarget::Fatbin;
2794}
2795
2796std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2798 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2799 llvm::StringSaver stringSaver(options.first);
2800 StringRef opts = cmdOptions;
2801  // For a correct tokenization of the command line options, `opts` must be
2802  // unquoted; otherwise the tokenization function returns a single string:
2803  // the quoted `cmdOptions`, which is not the desired behavior.
2804 // Remove any quotes if they are at the beginning and end of the string:
2805 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2806 opts.consume_front("\""), opts.consume_back("\"");
2807 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2808 opts.consume_front("'"), opts.consume_back("'");
2809#ifdef _WIN32
2810 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2811 /*MarkEOLs=*/false);
2812#else
2813 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2814 /*MarkEOLs=*/false);
2815#endif // _WIN32
2816 return options;
2817}
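// Editorial example: cmdOptions = "'-O3 --use_fast_math'" is first stripped of
// the surrounding quotes and then tokenized into {"-O3", "--use_fast_math"};
// without the unquoting step the whole string would come back as one token.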
2818
2819std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2820TargetOptions::tokenizeCmdOptions() const {
2821  return tokenizeCmdOptions(cmdOptions);
2822}
2823
2824std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2825TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2826  size_t startPos = cmdOptions.find(startsWith);
2827 if (startPos == std::string::npos)
2828 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2829
2830 auto tokenized =
2831 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2832 cmdOptions.resize(startPos);
2833 return tokenized;
2834}
2835
2837
2838#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2839#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2840
2841#define GET_ATTRDEF_CLASSES
2842#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2843
2844#define GET_OP_CLASSES
2845#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2846
2847#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
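A sketch of a parse method using these hooks to read a parenthesized operand list, resolve the operands to index values, and parse a trailing region; the grammar is illustrative:
ParseResult parseExample(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
    return failure();
  Type indexTy = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(operands, indexTy, result.operands))
    return failure();
  Region *body = result.addRegion();
  return parser.parseRegion(*body);
}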
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
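And the matching printer side, sketched with an illustrative elided attribute name:
void printExample(OpAsmPrinter &p, Operation *op, Region &body) {
  p.printRegion(body, /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"kernel"});
}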
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
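A sketch combining these OpBuilder facilities; the InsertionGuard restores the caller's insertion point once the new block has been populated, and the single index block argument is illustrative:
void populateRegion(OpBuilder &b, Region &region, Location loc) {
  OpBuilder::InsertionGuard guard(b);
  Type indexTy = b.getIndexType();
  Block *entry = b.createBlock(&region, region.begin(), {indexTy}, {loc});
  b.setInsertionPointToStart(entry);
  // ... create ops at the start of `entry` here ...
}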
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
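A verifier-style sketch exercising these Operation accessors and diagnostics; the "block_size" attribute name is illustrative:
LogicalResult checkOp(Operation *op) {
  if (auto size = op->getAttrOfType<IntegerAttr>("block_size"))
    if (size.getInt() <= 0)
      return op->emitOpError("expected a positive block size");
  if (!op->getParentOp())
    return op->emitError("expected op to be nested in another op");
  return success();
}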
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
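A minimal rewrite-pattern sketch using these hooks; MyOp is hypothetical and assumed to simply forward its single operand:
struct FoldMyOp : OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // Replace the op's result with its operand and erase the op.
    rewriter.replaceOp(op, op.getOperand());
    return success();
  }
};
// Registration into a pattern set: patterns.add<FoldMyOp>(context);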
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
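A sketch dispatching on these Type predicates:
StringRef classify(Type type) {
  if (type.isF16())  return "f16";
  if (type.isF32())  return "f32";
  if (type.isF64())  return "f64";
  if (type.isIndex()) return "index";
  if (type.isSignedInteger())   return "signed integer";
  if (type.isUnsignedInteger()) return "unsigned integer";
  if (type.isInteger())         return "signless integer";
  return "other";
}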
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
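A sketch tracing a Value back to its producer and forward to its users with these accessors:
void inspect(Value v) {
  if (Operation *def = v.getDefiningOp())
    llvm::outs() << "defined by " << def->getName() << "\n";
  else
    llvm::outs() << "block argument of type " << v.getType() << "\n";
  for (Operation *user : v.getUsers())
    llvm::outs() << "used at " << user->getLoc() << "\n";
}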
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
unsigned getNumDims() const
Get number of dims.
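A sketch building and querying an MMAMatrixType; the 16x16 f16 "AOp" operand follows the C += A*B convention noted above:
MMAMatrixType buildMatrixType(MLIRContext *ctx) {
  Builder b(ctx);
  Type f16 = b.getF16Type();
  assert(MMAMatrixType::isValidElementType(f16));
  MMAMatrixType matTy = MMAMatrixType::get({16, 16}, f16, "AOp");
  assert(matTy.getNumDims() == 2 && matTy.getShape()[0] == 16);
  assert(matTy.getElementType() == f16 && matTy.getOperand() == "AOp");
  return matTy;
}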
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
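A sketch constructing TargetOptions and consuming its accessors; the toolkit path and command-line flags are illustrative:
TargetOptions opts(/*toolkitPath=*/"/usr/local/cuda",
                   /*librariesToLink=*/{},
                   /*cmdOptions=*/"-v --fast-math");
auto [alloc, argv] = opts.tokenizeCmdOptions();          // {"-v", "--fast-math"}
CompilationTarget target = opts.getCompilationTarget();  // Fatbin by default
StringRef toolkit = opts.getToolkitPath();               // "/usr/local/cuda"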
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39