//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical ids (resp. logical ids) are 2 (0), 4 (1) and 5 (2).
/// The logical id for e.g. physical id 5 (logical 2) is computed with the
/// filter ((1 << 5) - 1):
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 0 1 1 1 1 1
/// Intersection   : 0 0 0 0 1 0 1 0 0
/// PopCnt         : 2
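///
/// In other words, the logical id of an active physical id equals the number
/// of active physical ids strictly below it in the mask.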
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical ids (resp. logical ids) are 2 (0), 4 (1) and 5 (2).
/// Whether e.g. physical id 5 (logical 2) is active is tested with the
/// filter (1 << 5):
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 1 0 0 0 0 0
/// Intersection   : 0 0 0 1 0 0 0 0 0
/// Cmp            : 1
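///
/// In other words, the predicate is true iff the mask bit for the given
/// physical id is set.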
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

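// Note: an MMA matrix type prints as e.g. !gpu.mma_matrix<16x16xf16, "AOp">,
// where the operand string is one of "AOp", "BOp" or "COp" (see the parser,
// printer, and verifier below).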
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print a refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}

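/// The known launch-size attributes verified below hold a dense i32 array of
/// exactly three elements, e.g. `gpu.known_block_size = array<i32: 128, 1, 1>`.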
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints an optional list of async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
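/// For example: `async [%dep0, %dep1]`.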
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
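/// For example:
/// `workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)`.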
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}


//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}


LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies are the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
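// For example: `(%bx, %by, %bz) in (%sx = %0, %sy = %1, %sz = %2)`.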
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               (`dynamic_shared_memory_size` ssa-use)?
///               (`module(` symbol-ref-id `)`)?
///               (`function(` symbol-ref-id `)`)?
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
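///
/// For example:
///
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
///              threads(%tx, %ty, %tz) in (%bw = %3, %bh = %4, %bd = %5) {
///     gpu.terminator
///   }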
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0) {
    if (!asyncTokenType)
      return parser.emitError(
          parser.getNameLoc(),
          "gpu.launch requires 'async' keyword to return a value");
    result.types.push_back(asyncTokenType);
  }

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size; in the region argument
  // list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
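///
/// For example, when the grid size in a dimension is a constant 1, every use
/// of the corresponding block id can be replaced by the constant 0.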
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
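///
/// When the two barriers fence different sets of address spaces, the surviving
/// barrier is updated to fence the union of both sets.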
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
  if (!nextOp)
    return failure();

  std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
  std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();

  if (thisMemfence) {
    rewriter.modifyOpInPlace(op, [&]() {
      if (!nextMemfence) {
        op.removeAddressSpacesAttr();
        return;
      }
      // Fast path - merge where the two barriers fence the same spaces.
      if (*thisMemfence == *nextMemfence) {
        return;
      }

      llvm::SmallSetVector<Attribute, 4> mergedSpaces;
      for (Attribute attr : *thisMemfence)
        mergedSpaces.insert(attr);
      for (Attribute attr : *nextMemfence)
        mergedSpaces.insert(attr);
      op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
    });
  }

  rewriter.eraseOp(nextOp);
  return success();
}

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

void BarrierOp::build(mlir::OpBuilder &odsBuilder,
                      mlir::OperationState &odsState,
                      std::optional<AddressSpace> addressSpace) {
  ArrayAttr addressSpacesAttr;
  if (addressSpace)
    addressSpacesAttr = odsBuilder.getArrayAttr(
        AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
  build(odsBuilder, odsState, addressSpacesAttr);
}

/// Builds a barrier that causes memory operations affecting `memrefToFence` to
/// be completed after the barrier is concluded. Currently, this means setting
/// the fenced address spaces to those of the given memref if it is a gpu
/// address space.
void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
                      Value memrefToFence) {
  std::optional<AddressSpace> addrSpaceToFence;
  if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
    if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
            memrefType.getMemorySpace()))
      addrSpaceToFence = addrSpaceAttr.getValue();
  return build(builder, odsState, addrSpaceToFence);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}
1562
1563void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1564 StringRef name, FunctionType type,
1565 TypeRange workgroupAttributions,
1566 TypeRange privateAttributions,
1567 ArrayRef<NamedAttribute> attrs) {
1568 OpBuilder::InsertionGuard g(builder);
1569
1571 builder.getStringAttr(name));
1572 result.addAttribute(getFunctionTypeAttrName(result.name),
1573 TypeAttr::get(type));
1574 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1575 builder.getI64IntegerAttr(workgroupAttributions.size()));
1576 result.addAttributes(attrs);
1577 Region *body = result.addRegion();
1578 Block *entryBlock = builder.createBlock(body);
1579
1580 // TODO: Allow passing in proper locations here.
1581 for (Type argTy : type.getInputs())
1582 entryBlock->addArgument(argTy, result.location);
1583 for (Type argTy : workgroupAttributions)
1584 entryBlock->addArgument(argTy, result.location);
1585 for (Type argTy : privateAttributions)
1586 entryBlock->addArgument(argTy, result.location);
1587}
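// Entry-block layout produced by the builder above (a sketch): for a
// function type (f32, i32) -> () with one workgroup and one private
// attribution, the entry block receives, in order, the function inputs
// (%arg0: f32, %arg1: i32), then the workgroup attribution, then the
// private attribution, all placed at `result.location` until proper
// locations can be passed in.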
1588
1589/// Parses a GPU function memory attribution.
1590///
1591/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1592/// (`private` `(` ssa-id-and-type-list `)`)?
1593///
1594/// Note that this function parses only one of the two similar parts, with the
1595/// keyword provided as argument.
1596static ParseResult
1597parseAttributions(OpAsmParser &parser, StringRef keyword,
1598                  SmallVectorImpl<OpAsmParser::Argument> &args,
1599                  Attribute &attributionAttrs) {
1600 // If we could not parse the keyword, just assume empty list and succeed.
1601 if (failed(parser.parseOptionalKeyword(keyword)))
1602 return success();
1603
1604 size_t existingArgs = args.size();
1605 ParseResult result =
1606      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
1607                               /*allowType=*/true, /*allowAttrs=*/true);
1608 if (failed(result))
1609 return result;
1610
1611 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1612 [](const OpAsmParser::Argument &arg) -> bool {
1613 return arg.attrs && !arg.attrs.empty();
1614 });
1615 if (!hadAttrs) {
1616 attributionAttrs = nullptr;
1617 return result;
1618 }
1619
1620 Builder &builder = parser.getBuilder();
1621 SmallVector<Attribute> attributionAttrsVec;
1622 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1623 if (!argument.attrs)
1624 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1625 else
1626 attributionAttrsVec.push_back(argument.attrs);
1627 }
1628 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1629 return result;
1630}
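// Example of the attribution clauses this helper parses, one keyword per
// call (an illustrative sketch):
//
//   workgroup(%sum : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%tmp : memref<1xf32, #gpu.address_space<private>>)
//
// If any newly parsed argument carries an attribute dictionary, the
// dictionaries of all newly parsed arguments are collected into an
// ArrayAttr (with empty dictionaries for arguments without attributes).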
1631
1632/// Parses a GPU function.
1633///
1634/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1635/// (`->` function-result-list)? memory-attribution `kernel`?
1636/// function-attributes? region
1637ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1638 SmallVector<OpAsmParser::Argument> entryArgs;
1639 SmallVector<DictionaryAttr> resultAttrs;
1640 SmallVector<Type> resultTypes;
1641 bool isVariadic;
1642
1643 // Parse the function name.
1644 StringAttr nameAttr;
1645  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1646                             result.attributes))
1647 return failure();
1648
1649 auto signatureLocation = parser.getCurrentLocation();
1650  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1651          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1652 resultAttrs)))
1653 return failure();
1654
1655 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1656 return parser.emitError(signatureLocation)
1657 << "gpu.func requires named arguments";
1658
1659 // Construct the function type. More types will be added to the region, but
1660 // not to the function type.
1661 Builder &builder = parser.getBuilder();
1662
1663 SmallVector<Type> argTypes;
1664 for (auto &arg : entryArgs)
1665 argTypes.push_back(arg.type);
1666 auto type = builder.getFunctionType(argTypes, resultTypes);
1667 result.addAttribute(getFunctionTypeAttrName(result.name),
1668 TypeAttr::get(type));
1669
1670  call_interface_impl::addArgAndResultAttrs(
1671      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1672 getResAttrsAttrName(result.name));
1673
1674 Attribute workgroupAttributionAttrs;
1675 // Parse workgroup memory attributions.
1676 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1677 entryArgs, workgroupAttributionAttrs)))
1678 return failure();
1679
1680 // Store the number of operands we just parsed as the number of workgroup
1681 // memory attributions.
1682 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1683 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1684 builder.getI64IntegerAttr(numWorkgroupAttrs));
1685 if (workgroupAttributionAttrs)
1686 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1687 workgroupAttributionAttrs);
1688
1689 Attribute privateAttributionAttrs;
1690 // Parse private memory attributions.
1691 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1692 entryArgs, privateAttributionAttrs)))
1693 return failure();
1694 if (privateAttributionAttrs)
1695 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1696 privateAttributionAttrs);
1697
1698 // Parse the kernel attribute if present.
1699 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1700 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1701 builder.getUnitAttr());
1702
1703 // Parse attributes.
1704 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1705 return failure();
1706
1707 // Parse the region. If no argument names were provided, take all names
1708 // (including those of attributions) from the entry block.
1709 auto *body = result.addRegion();
1710 return parser.parseRegion(*body, entryArgs);
1711}
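// Illustrative input matching the grammar above (a sketch, not a test):
//
//   gpu.func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>)
//       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//       kernel {
//     gpu.return
//   }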
1712
1713void GPUFuncOp::print(OpAsmPrinter &p) {
1714 p << ' ';
1715 p.printSymbolName(getName());
1716
1717 FunctionType type = getFunctionType();
1718 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1719 /*isVariadic=*/false,
1720 type.getResults());
1721
1722 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1723 getWorkgroupAttribAttrs().value_or(nullptr));
1724 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1725 getPrivateAttribAttrs().value_or(nullptr));
1726 if (isKernel())
1727 p << ' ' << getKernelKeyword();
1728
1729  function_interface_impl::printFunctionAttributes(
1730      p, *this,
1731 {getNumWorkgroupAttributionsAttrName(),
1732 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1733 getArgAttrsAttrName(), getResAttrsAttrName(),
1734 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1735 p << ' ';
1736 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1737}
1738
1739static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1740 StringAttr attrName) {
1741 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1742 if (!allAttrs || index >= allAttrs.size())
1743 return DictionaryAttr();
1744 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1745}
1746
1747DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1748 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1749}
1750
1751DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1752 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1753}
1754
1755static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1756 DictionaryAttr value, StringAttr attrName) {
1757 MLIRContext *ctx = op.getContext();
1758 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1759 SmallVector<Attribute> elements;
1760 if (allAttrs)
1761 elements.append(allAttrs.begin(), allAttrs.end());
1762 while (elements.size() <= index)
1763 elements.push_back(DictionaryAttr::get(ctx));
1764 if (!value)
1765 elements[index] = DictionaryAttr::get(ctx);
1766 else
1767 elements[index] = value;
1768 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1769 op->setAttr(attrName, newValue);
1770}
1771
1772void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1773 DictionaryAttr value) {
1774 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1775}
1776
1777void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1778 DictionaryAttr value) {
1779 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1780}
1781
1782static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1783 StringAttr name, StringAttr attrsName) {
1784 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1785 if (!dict)
1786 return Attribute();
1787 return dict.get(name);
1788}
1789
1790Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1791 StringAttr name) {
1792 assert(index < getNumWorkgroupAttributions() &&
1793 "index must map to a workgroup attribution");
1794 return getAttributionAttr(*this, index, name,
1795 getWorkgroupAttribAttrsAttrName());
1796}
1797
1798Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1799 StringAttr name) {
1800 assert(index < getNumPrivateAttributions() &&
1801 "index must map to a private attribution");
1802 return getAttributionAttr(*this, index, name,
1803 getPrivateAttribAttrsAttrName());
1804}
1805
1806static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1807 Attribute value, StringAttr attrsName) {
1808 MLIRContext *ctx = op.getContext();
1809  SmallVector<NamedAttribute> elems;
1810  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1811 if (oldDict)
1812 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1813
1814 bool found = false;
1815 bool mustSort = true;
1816 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1817 if (elems[i].getName() == name) {
1818 found = true;
1819 if (!value) {
1820 std::swap(elems[i], elems[elems.size() - 1]);
1821 elems.pop_back();
1822 } else {
1823 mustSort = false;
1824 elems[i] = NamedAttribute(elems[i].getName(), value);
1825 }
1826 break;
1827 }
1828 }
1829 if (!found) {
1830 if (!value)
1831 return;
1832 elems.emplace_back(name, value);
1833 }
1834 if (mustSort) {
1835 DictionaryAttr::sortInPlace(elems);
1836 }
1837 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1838 setAttributionAttrs(op, index, newDict, attrsName);
1839}
1840
1841void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1842 Attribute value) {
1843 assert(index < getNumWorkgroupAttributions() &&
1844 "index must map to a workgroup attribution");
1845 setAttributionAttr(*this, index, name, value,
1846 getWorkgroupAttribAttrsAttrName());
1847}
1848
1849void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1850 Attribute value) {
1851 assert(index < getNumPrivateAttributions() &&
1852 "index must map to a private attribution");
1853 setAttributionAttr(*this, index, name, value,
1854 getPrivateAttribAttrsAttrName());
1855}
1856
1857LogicalResult GPUFuncOp::verifyType() {
1858 if (isKernel() && getFunctionType().getNumResults() != 0)
1859 return emitOpError() << "expected void return type for kernel function";
1860
1861 return success();
1862}
1863
1864/// Verifies the body of the function.
1865LogicalResult GPUFuncOp::verifyBody() {
1866 if (empty())
1867 return emitOpError() << "expected body with at least one block";
1868 unsigned numFuncArguments = getNumArguments();
1869 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1870 unsigned numBlockArguments = front().getNumArguments();
1871 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1872 return emitOpError() << "expected at least "
1873 << numFuncArguments + numWorkgroupAttributions
1874 << " arguments to body region";
1875
1876 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1877 for (unsigned i = 0; i < numFuncArguments; ++i) {
1878 Type blockArgType = front().getArgument(i).getType();
1879 if (funcArgTypes[i] != blockArgType)
1880 return emitOpError() << "expected body region argument #" << i
1881 << " to be of type " << funcArgTypes[i] << ", got "
1882 << blockArgType;
1883 }
1884
1885 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1886 GPUDialect::getWorkgroupAddressSpace())) ||
1887 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1888 GPUDialect::getPrivateAddressSpace())))
1889 return failure();
1890
1891 return success();
1892}
1893
1894//===----------------------------------------------------------------------===//
1895// ReturnOp
1896//===----------------------------------------------------------------------===//
1897
1898LogicalResult gpu::ReturnOp::verify() {
1899 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1900
1901 FunctionType funType = function.getFunctionType();
1902
1903 if (funType.getNumResults() != getOperands().size())
1904 return emitOpError()
1905 .append("expected ", funType.getNumResults(), " result operands")
1906 .attachNote(function.getLoc())
1907 .append("return type declared here");
1908
1909 for (const auto &pair : llvm::enumerate(
1910 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1911 auto [type, operand] = pair.value();
1912 if (type != operand.getType())
1913 return emitOpError() << "unexpected type `" << operand.getType()
1914 << "' for operand #" << pair.index();
1915 }
1916 return success();
1917}
1918
1919//===----------------------------------------------------------------------===//
1920// GPUModuleOp
1921//===----------------------------------------------------------------------===//
1922
1923void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1924 StringRef name, ArrayAttr targets,
1925 Attribute offloadingHandler) {
1926 result.addRegion()->emplaceBlock();
1927 Properties &props = result.getOrAddProperties<Properties>();
1928 if (targets)
1929 props.targets = targets;
1930 props.setSymName(builder.getStringAttr(name));
1931 props.offloadingHandler = offloadingHandler;
1932}
1933
1934void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1935 StringRef name, ArrayRef<Attribute> targets,
1936 Attribute offloadingHandler) {
1937 build(builder, result, name,
1938 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1939 offloadingHandler);
1940}
1941
1942bool GPUModuleOp::hasTarget(Attribute target) {
1943 if (ArrayAttr targets = getTargetsAttr())
1944 return llvm::count(targets.getValue(), target);
1945 return false;
1946}
1947
1948void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1949 ArrayAttr &targetsAttr = getProperties().targets;
1950 SmallVector<Attribute> targetsVector(targets);
1951 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1952}
1953
1954LogicalResult GPUModuleOp::verify() {
1955 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1956
1957 if (!targets)
1958 return success();
1959
1960 for (auto target : targets) {
1961 if (auto verifyTargetAttr =
1962 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1963 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1964 return failure();
1965 }
1966 }
1967 return success();
1968}
1969
1970//===----------------------------------------------------------------------===//
1971// GPUBinaryOp
1972//===----------------------------------------------------------------------===//
1973void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1974 Attribute offloadingHandler, ArrayAttr objects) {
1975 auto &properties = result.getOrAddProperties<Properties>();
1976  result.attributes.push_back(builder.getNamedAttr(
1977      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1978 properties.objects = objects;
1979 if (offloadingHandler)
1980 properties.offloadingHandler = offloadingHandler;
1981 else
1982 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1983}
1984
1985void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1986 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1987 build(builder, result, name, offloadingHandler,
1988 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1989}
1990
1991static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1992 Attribute &offloadingHandler) {
1993 if (succeeded(parser.parseOptionalLess())) {
1994 if (parser.parseAttribute(offloadingHandler))
1995 return failure();
1996 if (parser.parseGreater())
1997 return failure();
1998 }
1999 if (!offloadingHandler)
2000 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2001 return success();
2002}
2003
2004static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
2005                                   Attribute offloadingHandler) {
2006 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2007 printer << '<' << offloadingHandler << '>';
2008}
2009
2010//===----------------------------------------------------------------------===//
2011// GPUMemcpyOp
2012//===----------------------------------------------------------------------===//
2013
2014LogicalResult MemcpyOp::verify() {
2015 auto srcType = getSrc().getType();
2016 auto dstType = getDst().getType();
2017
2018 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2019 return emitOpError("arguments have incompatible element type");
2020
2021 if (failed(verifyCompatibleShape(srcType, dstType)))
2022 return emitOpError("arguments have incompatible shape");
2023
2024 return success();
2025}
2026
2027namespace {
2028
2029/// Erases a common case of copy ops where a destination value is used only by
2030/// the copy op, alloc and dealloc ops.
2031struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2032 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2033
2034 LogicalResult matchAndRewrite(MemcpyOp op,
2035 PatternRewriter &rewriter) const override {
2036 Value dest = op.getDst();
2037 Operation *destDefOp = dest.getDefiningOp();
2038 // `dest` must be defined by an op having Allocate memory effect in order to
2039 // perform the folding.
2040 if (!destDefOp ||
2041        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
2042      return failure();
2043 // We can erase `op` iff `dest` has no other use apart from its
2044 // use by `op` and dealloc ops.
2045 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2046 return user != op &&
2047 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2048 }))
2049 return failure();
2050 // We can perform the folding if and only if op has a single async
2051 // dependency and produces an async token as result, or if it does not have
2052 // any async dependency and does not produce any async token result.
2053 if (op.getAsyncDependencies().size() > 1 ||
2054 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2055 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2056 return failure();
2057 rewriter.replaceOp(op, op.getAsyncDependencies());
2058 return success();
2059 }
2060};
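// Illustrative IR erased by EraseTrivialCopyOp (a sketch): %dst is only
// allocated, overwritten by the copy, and deallocated, so the copy is dead:
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>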
2061
2062} // end anonymous namespace
2063
2064void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2065 MLIRContext *context) {
2066 results.add<EraseTrivialCopyOp>(context);
2067}
2068
2069//===----------------------------------------------------------------------===//
2070// GPU_SubgroupMmaLoadMatrixOp
2071//===----------------------------------------------------------------------===//
2072
2073LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2074 auto srcType = getSrcMemref().getType();
2075 auto resType = getRes().getType();
2076 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2077 auto operand = resMatrixType.getOperand();
2078 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2079
2080 if (!srcMemrefType.isLastDimUnitStride())
2081 return emitError(
2082        "expected source memref most minor dim to have unit stride");
2083
2084 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2085 return emitError("only AOp, BOp and COp can be loaded");
2086
2087 return success();
2088}
2089
2090//===----------------------------------------------------------------------===//
2091// GPU_SubgroupMmaStoreMatrixOp
2092//===----------------------------------------------------------------------===//
2093
2094LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2095 auto srcType = getSrc().getType();
2096 auto dstType = getDstMemref().getType();
2097 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2098 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2099
2100 if (!dstMemrefType.isLastDimUnitStride())
2101 return emitError(
2102        "expected destination memref most minor dim to have unit stride");
2103
2104 if (srcMatrixType.getOperand() != "COp")
2105 return emitError(
2106 "expected the operand matrix being stored to have 'COp' operand type");
2107
2108 return success();
2109}
2110
2111//===----------------------------------------------------------------------===//
2112// GPU_SubgroupMmaComputeOp
2113//===----------------------------------------------------------------------===//
2114
2115LogicalResult SubgroupMmaComputeOp::verify() {
2116 enum OperandMap { A, B, C };
2117 SmallVector<MMAMatrixType, 3> opTypes;
2118 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2119 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2120 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2121
2122 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2123 opTypes[C].getOperand() != "COp")
2124 return emitError("operands must be in the order AOp, BOp, COp");
2125
2126 ArrayRef<int64_t> aShape, bShape, cShape;
2127 aShape = opTypes[A].getShape();
2128 bShape = opTypes[B].getShape();
2129 cShape = opTypes[C].getShape();
2130
2131 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2132 bShape[1] != cShape[1])
2133 return emitError("operand shapes do not satisfy matmul constraints");
2134
2135 return success();
2136}
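// Worked example for the shape check above (a sketch): with
//   A : !gpu.mma_matrix<16x16xf16, "AOp">
//   B : !gpu.mma_matrix<16x16xf16, "BOp">
//   C : !gpu.mma_matrix<16x16xf16, "COp">
// aShape[1] == bShape[0], aShape[0] == cShape[0], and bShape[1] == cShape[1]
// all hold, so the 16x16x16 matmul verifies.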
2137
2138LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2139 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2140 return memref::foldMemRefCast(*this);
2141}
2142
2143LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2144 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2145 return memref::foldMemRefCast(*this);
2146}
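// Illustrative effect of memref::foldMemRefCast in the two folders above (a
// sketch): an operand produced by a compatible memref.cast is folded to the
// cast's source, e.g.
//
//   %c = memref.cast %m : memref<16xf32> to memref<?xf32>
//   gpu.memcpy %dst, %c : memref<?xf32>, memref<?xf32>
//
// is rewritten so the copy uses %m directly.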
2147
2148//===----------------------------------------------------------------------===//
2149// GPU_WaitOp
2150//===----------------------------------------------------------------------===//
2151
2152namespace {
2153
2154/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2155/// %t = gpu.wait async [] // No async dependencies.
2156/// ... gpu.wait ... [%t, ...] // %t can be removed.
2157struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2158public:
2159  using OpRewritePattern<WaitOp>::OpRewritePattern;
2160
2161 LogicalResult matchAndRewrite(WaitOp op,
2162 PatternRewriter &rewriter) const final {
2163 auto predicate = [](Value value) {
2164 auto waitOp = value.getDefiningOp<WaitOp>();
2165 return waitOp && waitOp->getNumOperands() == 0;
2166 };
2167 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2168 return failure();
2169 SmallVector<Value> validOperands;
2170 for (Value operand : op->getOperands()) {
2171 if (predicate(operand))
2172 continue;
2173 validOperands.push_back(operand);
2174 }
2175 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2176 return success();
2177 }
2178};
2179
2180/// Simplify trivial gpu.wait ops for the following patterns.
2181/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2182/// dependencies).
2183/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2184/// %t0.
2185/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2186/// dependencies nor return any token.
2187struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2188public:
2189  using OpRewritePattern<WaitOp>::OpRewritePattern;
2190
2191 LogicalResult matchAndRewrite(WaitOp op,
2192 PatternRewriter &rewriter) const final {
2193 // Erase gpu.wait ops that neither have any async dependencies nor return
2194 // any async token.
2195 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2196 rewriter.eraseOp(op);
2197 return success();
2198 }
2199 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2200 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2201 op.getAsyncToken()) {
2202 rewriter.replaceOp(op, op.getAsyncDependencies());
2203 return success();
2204 }
2205 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2206 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2207 rewriter.eraseOp(op);
2208 return success();
2209 }
2210 return failure();
2211 }
2212};
2213
2214} // end anonymous namespace
2215
2216void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2217 MLIRContext *context) {
2218 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2219}
2220
2221//===----------------------------------------------------------------------===//
2222// GPU_AllocOp
2223//===----------------------------------------------------------------------===//
2224
2225LogicalResult AllocOp::verify() {
2226 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2227
2228 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2229 getDynamicSizes())))
2230 return failure();
2231
2232 unsigned numSymbols = 0;
2233 if (!memRefType.getLayout().isIdentity())
2234 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2235 if (getSymbolOperands().size() != numSymbols) {
2236 return emitOpError(
2237 "symbol operand count does not equal memref symbol count");
2238 }
2239
2240 return success();
2241}
2242
2243namespace {
2244
2245/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2246/// `memref::AllocOp`.
2247struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2248 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2249
2250 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2251 PatternRewriter &rewriter) const override {
2252 std::optional<int64_t> index = dimOp.getConstantIndex();
2253 if (!index)
2254 return failure();
2255
2256 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2257 if (!memrefType || index.value() >= memrefType.getRank() ||
2258 !memrefType.isDynamicDim(index.value()))
2259 return failure();
2260
2261 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2262 if (!alloc)
2263 return failure();
2264
2265 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2266 memrefType.getDynamicDimIndex(index.value()));
2267 rewriter.replaceOp(dimOp, substituteOp);
2268 return success();
2269 }
2270};
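// Illustrative rewrite performed by SimplifyDimOfAllocOp (a sketch):
//
//   %m = gpu.alloc (%n) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
//
// Dimension #0 is dynamic and sized by %n, so uses of %d become %n.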
2271
2272} // namespace
2273
2274void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2275 MLIRContext *context) {
2276 results.add<SimplifyDimOfAllocOp>(context);
2277}
2278
2279//===----------------------------------------------------------------------===//
2280// GPU object attribute
2281//===----------------------------------------------------------------------===//
2282
2283LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2284 Attribute target, CompilationTarget format,
2285 StringAttr object, DictionaryAttr properties,
2286 KernelTableAttr kernels) {
2287 if (!target)
2288 return emitError() << "the target attribute cannot be null";
2289 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2290 return success();
2291 return emitError() << "the target attribute must implement or promise the "
2292 "`gpu::TargetAttrInterface`";
2293}
2294
2295namespace {
2296ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2297 StringAttr &object) {
2298 std::optional<CompilationTarget> formatResult;
2299 StringRef enumKeyword;
2300 auto loc = odsParser.getCurrentLocation();
2301 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2302 formatResult = CompilationTarget::Fatbin;
2303 if (!formatResult &&
2304 (formatResult =
2305 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2306 odsParser.parseEqual())
2307 return odsParser.emitError(loc, "expected an equal sign");
2308 if (!formatResult)
2309 return odsParser.emitError(loc, "expected keyword for GPU object format");
2310 FailureOr<StringAttr> objectResult =
2311 FieldParser<StringAttr>::parse(odsParser);
2312 if (failed(objectResult))
2313 return odsParser.emitError(odsParser.getCurrentLocation(),
2314 "failed to parse GPU_ObjectAttr parameter "
2315 "'object' which is to be a `StringAttr`");
2316 format = *formatResult;
2317 object = *objectResult;
2318 return success();
2319}
2320
2321void printObject(AsmPrinter &odsParser, CompilationTarget format,
2322 StringAttr object) {
2323 if (format != CompilationTarget::Fatbin)
2324 odsParser << stringifyEnum(format) << " = ";
2325 odsParser << object;
2326}
2327} // namespace
2328
2329//===----------------------------------------------------------------------===//
2330// GPU select object attribute
2331//===----------------------------------------------------------------------===//
2332
2333LogicalResult
2334gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2335 Attribute target) {
2336 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2337 if (target) {
2338 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2339 if (intAttr.getInt() < 0) {
2340        return emitError() << "the object index must be non-negative";
2341 }
2342 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2343 return emitError()
2344 << "the target attribute must be a GPU Target attribute";
2345 }
2346 }
2347 return success();
2348}
2349
2350//===----------------------------------------------------------------------===//
2351// DynamicSharedMemoryOp
2352//===----------------------------------------------------------------------===//
2353
2354LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2355 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2356 return emitOpError() << "must be inside an op with symbol table";
2357
2358 MemRefType memrefType = getResultMemref().getType();
2359 // Check address space
2360 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2361 return emitOpError() << "address space must be "
2362 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2363 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2364 }
2365 if (memrefType.hasStaticShape()) {
2366 return emitOpError() << "result memref type must be memref<?xi8, "
2367 "#gpu.address_space<workgroup>>";
2368 }
2369 return success();
2370}
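// Illustrative op accepted by the verifier above (a sketch):
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>
//
// A static shape or a non-workgroup memory space is rejected.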
2371
2372//===----------------------------------------------------------------------===//
2373// GPU WarpExecuteOnLane0Op
2374//===----------------------------------------------------------------------===//
2375
2376void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2377 p << "(" << getLaneid() << ")";
2378
2379 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2380 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2381 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2382
2383 if (!getArgs().empty())
2384 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2385 if (!getResults().empty())
2386 p << " -> (" << getResults().getTypes() << ')';
2387 p << " ";
2388 p.printRegion(getRegion(),
2389 /*printEntryBlockArgs=*/true,
2390 /*printBlockTerminators=*/!getResults().empty());
2391 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2392}
2393
2394ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2395 OperationState &result) {
2396 // Create the region.
2397 result.regions.reserve(1);
2398 Region *warpRegion = result.addRegion();
2399
2400 auto &builder = parser.getBuilder();
2401 OpAsmParser::UnresolvedOperand laneId;
2402
2403 // Parse predicate operand.
2404 if (parser.parseLParen() ||
2405 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2406 parser.parseRParen())
2407 return failure();
2408
2409 int64_t warpSize;
2410 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2411 parser.parseRSquare())
2412 return failure();
2413 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2414 builder.getContext())),
2415 builder.getI64IntegerAttr(warpSize));
2416
2417 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2418 return failure();
2419
2420 llvm::SMLoc inputsOperandsLoc;
2421 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2422 SmallVector<Type> inputTypes;
2423 if (succeeded(parser.parseOptionalKeyword("args"))) {
2424 if (parser.parseLParen())
2425 return failure();
2426
2427 inputsOperandsLoc = parser.getCurrentLocation();
2428 if (parser.parseOperandList(inputsOperands) ||
2429 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2430 return failure();
2431 }
2432 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2433 result.operands))
2434 return failure();
2435
2436 // Parse optional results type list.
2437 if (parser.parseOptionalArrowTypeList(result.types))
2438 return failure();
2439 // Parse the region.
2440 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2441 /*argTypes=*/{}))
2442 return failure();
2443 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2444
2445 // Parse the optional attribute list.
2446 if (parser.parseOptionalAttrDict(result.attributes))
2447 return failure();
2448 return success();
2449}
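// Illustrative form round-tripped by the printer/parser above (a sketch):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %v = "test.def"() : () -> vector<32xf32>
//     gpu.yield %v : vector<32xf32>
//   }
//
// The yielded vector<32xf32> is distributed across the 32 lanes, each lane
// receiving a vector<1xf32> result.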
2450
2451void WarpExecuteOnLane0Op::getSuccessorRegions(
2452 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2453 if (!point.isParent()) {
2454 regions.push_back(RegionSuccessor::parent());
2455 return;
2456 }
2457
2458  // The warp region is always executed.
2459 regions.push_back(RegionSuccessor(&getWarpRegion()));
2460}
2461
2462ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2463 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
2464}
2465void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2466 TypeRange resultTypes, Value laneId,
2467 int64_t warpSize) {
2468 build(builder, result, resultTypes, laneId, warpSize,
2469 /*operands=*/{}, /*argTypes=*/{});
2470}
2471
2472void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2473 TypeRange resultTypes, Value laneId,
2474 int64_t warpSize, ValueRange args,
2475 TypeRange blockArgTypes) {
2476 result.addOperands(laneId);
2477 result.addAttribute(getAttributeNames()[0],
2478 builder.getI64IntegerAttr(warpSize));
2479 result.addTypes(resultTypes);
2480 result.addOperands(args);
2481 assert(args.size() == blockArgTypes.size());
2482 OpBuilder::InsertionGuard guard(builder);
2483 Region *warpRegion = result.addRegion();
2484 Block *block = builder.createBlock(warpRegion);
2485 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2486 block->addArgument(type, arg.getLoc());
2487}
2488
2489/// Helper check if the distributed vector type is consistent with the expanded
2490/// type and distributed size.
2491static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2492 int64_t warpSize, Operation *op) {
2493  // If the types match, there is no distribution.
2494 if (expanded == distributed)
2495 return success();
2496 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2497 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2498 if (!expandedVecType || !distributedVecType)
2499 return op->emitOpError("expected vector type for distributed operands.");
2500 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2501 expandedVecType.getElementType() != distributedVecType.getElementType())
2502 return op->emitOpError(
2503 "expected distributed vectors to have same rank and element type.");
2504
2505 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2506 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2507 int64_t eDim = expandedVecType.getDimSize(i);
2508 int64_t dDim = distributedVecType.getDimSize(i);
2509 if (eDim == dDim)
2510 continue;
2511 if (eDim % dDim != 0)
2512 return op->emitOpError()
2513 << "expected expanded vector dimension #" << i << " (" << eDim
2514        << ") to be a multiple of the distributed vector dimension ("
2515 << dDim << ")";
2516 scales[i] = eDim / dDim;
2517 }
2518 if (llvm::product_of(scales) != warpSize)
2519 return op->emitOpError()
2520 << "incompatible distribution dimensions from " << expandedVecType
2521 << " to " << distributedVecType << " with warp size = " << warpSize;
2522
2523 return success();
2524}
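// Worked example for the check above (a sketch): expanded vector<64x4xf32>
// distributed to vector<2x4xf32> gives scales = {32, 1}, whose product
// equals warpSize = 32, so verification succeeds; distribution to
// vector<3x4xf32> fails because 64 % 3 != 0.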
2525
2526LogicalResult WarpExecuteOnLane0Op::verify() {
2527 if (getArgs().size() != getWarpRegion().getNumArguments())
2528 return emitOpError(
2529        "expected same number of op arguments and block arguments.");
2530 gpu::YieldOp yield = getTerminator();
2531 if (yield.getNumOperands() != getNumResults())
2532 return emitOpError(
2533 "expected same number of yield operands and return values.");
2534 int64_t warpSize = getWarpSize();
2535 for (auto [regionArg, arg] :
2536 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2537 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2538 warpSize, getOperation())))
2539 return failure();
2540 }
2541 for (auto [yieldOperand, result] :
2542 llvm::zip_equal(yield.getOperands(), getResults())) {
2543 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2544 warpSize, getOperation())))
2545 return failure();
2546 }
2547 return success();
2548}
2549bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2550 return succeeded(
2551 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2552}
2553
2554gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2555 return cast<gpu::YieldOp>(getBody()->getTerminator());
2556}
2557
2558//===----------------------------------------------------------------------===//
2559// GPU_SubgroupBroadcastOp
2560//===----------------------------------------------------------------------===//
2561
2562void gpu::SubgroupBroadcastOp::inferResultRanges(
2563 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2564 setResultRange(getResult(), argRanges.front());
2565}
2566
2567Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2568 switch (getBroadcastType()) {
2569 case BroadcastType::first_active_lane:
2570    // Cannot speculate a first_active_lane broadcast, because speculating
2571    // it across control flow can change the active lanes.
2572    return Speculation::NotSpeculatable;
2573  case BroadcastType::specific_lane:
2574    // Speculation is safe as long as we are inside structured control flow.
2575    return Speculation::Speculatable;
2576  }
2577 llvm_unreachable("Unknown BroadcastType");
2578}
2579
2580LogicalResult gpu::SubgroupBroadcastOp::verify() {
2581 switch (getBroadcastType()) {
2582 case BroadcastType::first_active_lane:
2583 if (getLane())
2584 return emitOpError()
2585 << "lane can only be specified for `specific_lane` broadcast";
2586 return success();
2587 case BroadcastType::specific_lane:
2588 if (!getLane())
2589 return emitOpError()
2590 << "lane must be specified for `specific_lane` broadcast";
2591 return success();
2592 }
2593 llvm_unreachable("Unknown BroadcastType");
2594}
2595
2596OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2597 // Broadcast result is always uniform.
2598 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2599 return prev.getResult();
2600
2601 return nullptr;
2602}
2603
2604//===----------------------------------------------------------------------===//
2605// GPU KernelMetadataAttr
2606//===----------------------------------------------------------------------===//
2607
2608KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2609 DictionaryAttr metadata) {
2610 assert(kernel && "invalid kernel");
2611 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2612 kernel.getAllArgAttrs(), metadata);
2613}
2614
2615KernelMetadataAttr
2616KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2617 FunctionOpInterface kernel,
2618 DictionaryAttr metadata) {
2619 assert(kernel && "invalid kernel");
2620 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2621 kernel.getAllArgAttrs(), metadata);
2622}
2623
2624KernelMetadataAttr
2625KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2626 if (attrs.empty())
2627 return *this;
2628 NamedAttrList attrList;
2629 if (DictionaryAttr dict = getMetadata())
2630 attrList.append(dict);
2631 attrList.append(attrs);
2632 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2633 attrList.getDictionary(getContext()));
2634}
2635
2636LogicalResult
2637KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2638 StringAttr name, Type functionType,
2639 ArrayAttr argAttrs, DictionaryAttr metadata) {
2640 if (name.empty())
2641 return emitError() << "the kernel name can't be empty";
2642 if (argAttrs) {
2643 if (llvm::any_of(argAttrs, [](Attribute attr) {
2644 return !llvm::isa<DictionaryAttr>(attr);
2645 }))
2646 return emitError()
2647 << "all attributes in the array must be a dictionary attribute";
2648 }
2649 return success();
2650}
2651
2652//===----------------------------------------------------------------------===//
2653// GPU KernelTableAttr
2654//===----------------------------------------------------------------------===//
2655
2656KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2657 ArrayRef<KernelMetadataAttr> kernels,
2658 bool isSorted) {
2659  // Note that `is_sorted` is invoked at most once, even with assertions ON.
2660 assert((!isSorted || llvm::is_sorted(kernels)) &&
2661 "expected a sorted kernel array");
2662 // Immediately return the attribute if the array is sorted.
2663 if (isSorted || llvm::is_sorted(kernels))
2664 return Base::get(context, kernels);
2665 // Sort the array.
2666 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2667 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2668 return Base::get(context, kernelsTmp);
2669}
2670
2671KernelTableAttr KernelTableAttr::getChecked(
2672 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2673 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2674  // Note that `is_sorted` is invoked at most once, even with assertions ON.
2675 assert((!isSorted || llvm::is_sorted(kernels)) &&
2676 "expected a sorted kernel array");
2677 // Immediately return the attribute if the array is sorted.
2678 if (isSorted || llvm::is_sorted(kernels))
2679 return Base::getChecked(emitError, context, kernels);
2680 // Sort the array.
2681 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2682 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2683 return Base::getChecked(emitError, context, kernelsTmp);
2684}
2685
2686LogicalResult
2687KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2688 ArrayRef<KernelMetadataAttr> kernels) {
2689 if (kernels.size() < 2)
2690 return success();
2691 // Check that the kernels are uniquely named.
2692 if (std::adjacent_find(kernels.begin(), kernels.end(),
2693 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2694 return l.getName() == r.getName();
2695 }) != kernels.end()) {
2696 return emitError() << "expected all kernels to be uniquely named";
2697 }
2698 return success();
2699}
2700
2701KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2702 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2703 return found ? *iterator : KernelMetadataAttr();
2704}
2705
2706KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2707 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2708 return found ? *iterator : KernelMetadataAttr();
2709}
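// Hypothetical usage sketch (`table` and "vecadd" are illustrative): the
// lookups above binary-search the name-sorted kernel array,
//
//   KernelMetadataAttr kernel = table.lookup("vecadd");
//   if (kernel) { /* hit */ }
//
// and return a null KernelMetadataAttr when the name is absent.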
2710
2711//===----------------------------------------------------------------------===//
2712// GPU target options
2713//===----------------------------------------------------------------------===//
2714
2729
2747
2748TypeID TargetOptions::getTypeID() const { return typeID; }
2749
2750StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2751
2752ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2753  return librariesToLink;
2754}
2755
2756StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2757
2758StringRef TargetOptions::getELFSection() const { return elfSection; }
2759
2760SymbolTable *TargetOptions::getSymbolTable() const {
2761  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2762}
2763
2764function_ref<void(llvm::Module &)>
2765TargetOptions::getInitializedLlvmIRCallback() const {
2766  return initializedLlvmIRCallback;
2767}
2768
2769function_ref<void(llvm::Module &)>
2770TargetOptions::getLinkedLlvmIRCallback() const {
2771  return linkedLlvmIRCallback;
2772}
2773
2774function_ref<void(llvm::Module &)>
2775TargetOptions::getOptimizedLlvmIRCallback() const {
2776  return optimizedLlvmIRCallback;
2777}
2778
2779function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2780  return isaCallback;
2781}
2782
2783CompilationTarget TargetOptions::getCompilationTarget() const {
2784 return compilationTarget;
2785}
2786
2787CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2788  return CompilationTarget::Fatbin;
2789}
2790
2791std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2792TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2793  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2794 llvm::StringSaver stringSaver(options.first);
2795 StringRef opts = cmdOptions;
2796  // For a correct tokenization of the command line options, `opts` must be
2797  // unquoted; otherwise the tokenization function returns a single string
2798  // (the quoted `cmdOptions`), which is not the desired behavior.
2799 // Remove any quotes if they are at the beginning and end of the string:
2800 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2801 opts.consume_front("\""), opts.consume_back("\"");
2802 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2803 opts.consume_front("'"), opts.consume_back("'");
2804#ifdef _WIN32
2805 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2806 /*MarkEOLs=*/false);
2807#else
2808 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2809 /*MarkEOLs=*/false);
2810#endif // _WIN32
2811 return options;
2812}
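// Illustrative behavior of the tokenizer above (a sketch): the quoted
// string "\"-O3 --use_fast_math\"" is first unquoted, then split into
// {"-O3", "--use_fast_math"} using Windows or GNU command-line rules
// depending on the host platform.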
2813
2814std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2815TargetOptions::tokenizeCmdOptions() const {
2816  return tokenizeCmdOptions(cmdOptions);
2817}
2818
2819std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2820TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2821  size_t startPos = cmdOptions.find(startsWith);
2822 if (startPos == std::string::npos)
2823 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2824
2825 auto tokenized =
2826 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2827 cmdOptions.resize(startPos);
2828 return tokenized;
2829}
2830
2831MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2832
2833#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2834#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2835
2836#define GET_ATTRDEF_CLASSES
2837#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2838
2839#define GET_OP_CLASSES
2840#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2841
2842#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:104
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
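These hooks support custom parsers of the usual shape. A hedged sketch for a hypothetical FooOp with variadic index operands and one body region (the op and its syntax are placeholders):
  static ParseResult parseFooOp(OpAsmParser &parser, OperationState &result) {
    SmallVector<OpAsmParser::UnresolvedOperand> operands;
    if (parser.parseOperandList(operands))      // zero or more %-operands
      return failure();
    Type indexTy = parser.getBuilder().getIndexType();
    if (parser.resolveOperands(operands, indexTy, result.operands))
      return failure();
    if (parser.parseOptionalAttrDict(result.attributes))
      return failure();
    Region *body = result.addRegion();
    return parser.parseRegion(*body);           // trailing body region
  }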
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print for an operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
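The matching printer side; again a sketch for the hypothetical FooOp above, assuming it provides getOperands() and a region accessor getBody():
  static void printFooOp(OpAsmPrinter &p, FooOp op) {
    p << ' ' << op.getOperands();
    p.printOptionalAttrDict(op->getAttrs());    // nothing elided here
    p << ' ';
    p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
  }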
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
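A sketch of the block-creation pattern these hooks enable; InsertionGuard restores the caller's insertion point when the helper returns:
  void buildEntryBlock(OpBuilder &builder, Region &region, Location loc) {
    OpBuilder::InsertionGuard guard(builder);
    SmallVector<Type> argTypes{builder.getIndexType()};
    SmallVector<Location> argLocs{loc};
    Block *entry =
        builder.createBlock(&region, region.end(), argTypes, argLocs);
    builder.setInsertionPointToStart(entry);
    // Ops created here land at the start of the new entry block.
  }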
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value; otherwise, add a new attribute with the specified name/value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
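A verifier-style sketch built from these Operation hooks; the attribute name is illustrative only:
  LogicalResult checkBlockSizes(Operation *op) {
    auto sizes = op->getAttrOfType<DenseI32ArrayAttr>("gpu.known_block_size");
    if (!sizes)
      return op->emitOpError("missing 'gpu.known_block_size' attribute");
    ArrayRef<int32_t> vals = sizes.asArrayRef();
    if (vals.size() != 3)
      return op->emitError() << "expected 3 sizes, got " << vals.size();
    return success();
  }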
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
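A sketch of a pattern wired through these hooks; FooOp is a placeholder single-operand, single-result op that simply forwards its input:
  struct ForwardFoo : OpRewritePattern<FooOp> {
    using OpRewritePattern<FooOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(FooOp op,
                                  PatternRewriter &rewriter) const override {
      if (op.getOperand().getType() != op.getType())
        return failure();
      rewriter.replaceOp(op, op.getOperand()); // forward the operand
      return success();
    }
  };
  // Registration: patterns.add<ForwardFoo>(context);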
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isIndex() const
Definition Types.cpp:54
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
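A small sketch dispatching on these predicates:
  StringRef classifyElementType(Type type) {
    if (type.isF16() || type.isF32() || type.isF64())
      return "float";
    if (type.isIndex())
      return "index";
    if (type.isInteger()) // any width: signless, signed, or unsigned
      return "integer";
    return "other";
  }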
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
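A sketch inspecting a Value through these accessors:
  void dumpValueInfo(Value v) {
    llvm::errs() << "type: " << v.getType() << "\n";
    if (Operation *def = v.getDefiningOp())
      llvm::errs() << "def: " << def->getName().getStringRef() << "\n";
    for (Operation *user : v.getUsers())
      llvm::errs() << "use: " << user->getLoc() << "\n";
  }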
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
unsigned getNumDims() const
Get number of dims.
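A sketch constructing and querying the type; the 16x16 f16 "AOp" operand is illustrative:
  MMAMatrixType buildAOperand(MLIRContext *ctx) {
    Type f16 = Float16Type::get(ctx);
    assert(MMAMatrixType::isValidElementType(f16));
    MMAMatrixType mat = MMAMatrixType::get({16, 16}, f16, "AOp");
    assert(mat.getNumDims() == 2 && mat.getElementType() == f16);
    return mat;
  }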
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options, the ELF section, the compilation target, and the callbacks invoked at the various stages of compilation.
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
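A sketch constructing TargetOptions and reading it back; the toolkit path and flags are placeholders:
  TargetOptions options(/*toolkitPath=*/"/usr/local/cuda",
                        /*librariesToLink=*/{},
                        /*cmdOptions=*/"-O3 --verbose");
  assert(options.getCompilationTarget() == CompilationTarget::Fatbin);
  auto [allocator, tokens] = options.tokenizeCmdOptions();
  // tokens should now hold {"-O3", "--verbose"}; allocator owns the storage.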
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
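A sketch of a function-like op's custom printer assembled from these helpers, assuming they live in the function_interface_impl namespace and op implements FunctionOpInterface:
  void printLikeFunction(OpAsmPrinter &p, FunctionOpInterface op) {
    p << ' ';
    p.printSymbolName(
        op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
            .getValue());
    function_interface_impl::printFunctionSignature(
        p, op, op.getArgumentTypes(), /*isVariadic=*/false,
        op.getResultTypes());
    function_interface_impl::printFunctionAttributes(
        p, op, /*elided=*/{SymbolTable::getSymbolAttrName()});
  }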
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator points to the found element or to the position where it would be inserted.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shaped type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
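A one-function sketch of the matcher entry point; this is the same constant-one test the gpu.launch simplification below depends on:
  bool isConstantOne(Value v) {
    return v && matchPattern(v, m_One());
  }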
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z.
Definition GPUDialect.h:39
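A closing sketch of KernelDim3 in use, assuming a LaunchOp launchOp is in scope; the struct packs three Values reachable as .x, .y, and .z:
  KernelDim3 blockIds = launchOp.getBlockIds();
  Value bx = blockIds.x, by = blockIds.y, bz = blockIds.z;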