MLIR 23.0.0git
GPUDialect.cpp
Go to the documentation of this file.
//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Matchers.h"
30#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
43#include <cassert>
44#include <numeric>
45
46using namespace mlir;
47using namespace mlir::gpu;
48
49#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// GPU Device Mapping Attributes
53//===----------------------------------------------------------------------===//
54
55int64_t GPUBlockMappingAttr::getMappingId() const {
56 return static_cast<int64_t>(getBlock());
57}
58
59bool GPUBlockMappingAttr::isLinearMapping() const {
60 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
61}
62
63int64_t GPUBlockMappingAttr::getRelativeIndex() const {
64 return isLinearMapping()
65 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
66 : getMappingId();
67}
68
69int64_t GPUWarpgroupMappingAttr::getMappingId() const {
70 return static_cast<int64_t>(getWarpgroup());
71}
72
73bool GPUWarpgroupMappingAttr::isLinearMapping() const {
74 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
75}
76
77int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
78 return isLinearMapping()
79 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
80 : getMappingId();
81}
82
83int64_t GPUWarpMappingAttr::getMappingId() const {
84 return static_cast<int64_t>(getWarp());
85}
86
87bool GPUWarpMappingAttr::isLinearMapping() const {
88 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
89}
90
91int64_t GPUWarpMappingAttr::getRelativeIndex() const {
92 return isLinearMapping()
93 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
94 : getMappingId();
95}
96
97int64_t GPUThreadMappingAttr::getMappingId() const {
98 return static_cast<int64_t>(getThread());
99}
100
101bool GPUThreadMappingAttr::isLinearMapping() const {
102 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
103}
104
105int64_t GPUThreadMappingAttr::getRelativeIndex() const {
106 return isLinearMapping()
107 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
108 : getMappingId();
109}
110
111int64_t GPULaneMappingAttr::getMappingId() const {
112 return static_cast<int64_t>(getLane());
113}
114
115bool GPULaneMappingAttr::isLinearMapping() const {
116 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
117}
118
119int64_t GPULaneMappingAttr::getRelativeIndex() const {
120 return isLinearMapping()
121 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
122 : getMappingId();
123}
124
125int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
126
127/// 8 4 0
128/// Example mask : 0 0 0 1 1 0 1 0 0
129///
130/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
131/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
132///
133/// Example mask : 0 0 0 1 1 0 1 0 0
134/// Example filter: 0 0 0 0 1 1 1 1 1
135/// Intersection : 0 0 0 0 1 0 1 0 0
136/// PopCnt : 2
137Value GPUMappingMaskAttr::createLogicalLinearMappingId(
138 OpBuilder &b, Value physicalLinearMappingId) const {
139 Location loc = physicalLinearMappingId.getLoc();
140 Value mask =
141 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
142 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
143 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
144 filter = arith::SubIOp::create(b, loc, filter, one);
145 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
146 return math::CtPopOp::create(b, loc, filteredId);
147}
148
149/// 8 4 0
150/// Example mask : 0 0 0 1 1 0 1 0 0
151///
152/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
153/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
154///
155/// Example mask : 0 0 0 1 1 0 1 0 0
156/// Example filter: 0 0 0 1 0 0 0 0 0
157/// Intersection : 0 0 0 1 0 0 0 0 0
158/// Cmp : 1
159Value GPUMappingMaskAttr::createIsActiveIdPredicate(
160 OpBuilder &b, Value physicalLinearMappingId) const {
161 Location loc = physicalLinearMappingId.getLoc();
162 Value mask =
163 arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
164 Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
165 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
166 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
167 Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
168 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
169 zero);
170}
171
172int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
173 return static_cast<int64_t>(getAddressSpace());
174}
175
/// Memory-space mappings have no linear form; callers must not query this.
bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}
179
/// Memory-space mappings have no relative index; callers must not query this.
int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
183
184//===----------------------------------------------------------------------===//
185// MMAMatrixType
186//===----------------------------------------------------------------------===//
187
189 StringRef operand) {
190 return Base::get(elementType.getContext(), shape, elementType, operand);
191}
192
195 ArrayRef<int64_t> shape, Type elementType,
196 StringRef operand) {
197 return Base::getChecked(emitError, elementType.getContext(), shape,
198 elementType, operand);
199}
200
201unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
202
204 return getImpl()->getShape();
205}
206
207Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
208
209StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
210
212 return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
213 elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
214 elementType.isInteger(32);
215}
216
217LogicalResult
219 ArrayRef<int64_t> shape, Type elementType,
220 StringRef operand) {
221 if (operand != "AOp" && operand != "BOp" && operand != "COp")
222 return emitError() << "operand expected to be one of AOp, BOp or COp";
223
224 if (shape.size() != 2)
225 return emitError() << "MMAMatrixType must have exactly two dimensions";
226
227 if (!MMAMatrixType::isValidElementType(elementType))
228 return emitError()
229 << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
230
231 return success();
232}
233
234//===----------------------------------------------------------------------===//
235// GPUDialect
236//===----------------------------------------------------------------------===//
237
238bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
239 if (!memorySpace)
240 return false;
241 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
242 return gpuAttr.getValue() == getWorkgroupAddressSpace();
243 return false;
244}
245
246bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
247 Attribute memorySpace = type.getMemorySpace();
248 return isWorkgroupMemoryAddressSpace(memorySpace);
249}
250
251bool GPUDialect::isKernel(Operation *op) {
252 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
253 return static_cast<bool>(isKernelAttr);
254}
255
namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
268
/// Registers the dialect's types, ops, attributes and interfaces. The op and
/// attribute lists are generated by TableGen and pulled in via the GET_*_LIST
/// macro includes below.
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  // Interfaces whose implementations live in other libraries; they are
  // promised here and attached lazily when those libraries are loaded.
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
                            ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
                            BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
                            LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
                            SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
}
292
293static std::string getSparseHandleKeyword(SparseHandleKind kind) {
294 switch (kind) {
296 return "sparse.dntensor_handle";
298 return "sparse.spmat_handle";
300 return "sparse.spgemmop_handle";
301 }
302 llvm_unreachable("unknown sparse handle kind");
303 return "";
304}
305
306Type GPUDialect::parseType(DialectAsmParser &parser) const {
307 // Parse the main keyword for the type.
308 StringRef keyword;
309 if (parser.parseKeyword(&keyword))
310 return Type();
311 MLIRContext *context = getContext();
312
313 // Handle 'async token' types.
314 if (keyword == "async.token")
315 return AsyncTokenType::get(context);
316
317 if (keyword == "mma_matrix") {
318 SMLoc beginLoc = parser.getNameLoc();
319
320 // Parse '<'.
321 if (parser.parseLess())
322 return nullptr;
323
324 // Parse the size and elementType.
325 SmallVector<int64_t> shape;
326 Type elementType;
327 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
328 parser.parseType(elementType))
329 return nullptr;
330
331 // Parse ','
332 if (parser.parseComma())
333 return nullptr;
334
335 // Parse operand.
336 std::string operand;
337 if (failed(parser.parseOptionalString(&operand)))
338 return nullptr;
339
340 // Parse '>'.
341 if (parser.parseGreater())
342 return nullptr;
343
345 parser.getEncodedSourceLoc(beginLoc)),
346 shape, elementType, operand);
347 }
348
350 return SparseDnTensorHandleType::get(context);
352 return SparseSpMatHandleType::get(context);
354 return SparseSpGEMMOpHandleType::get(context);
355
356 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
357 return Type();
358}
359// TODO: print refined type here. Notice that should be corresponding to the
360// parser
361void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
362 TypeSwitch<Type>(type)
363 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
364 .Case<SparseDnTensorHandleType>([&](Type) {
366 })
367 .Case<SparseSpMatHandleType>(
369 .Case<SparseSpGEMMOpHandleType>([&](Type) {
371 })
372 .Case([&](MMAMatrixType fragTy) {
373 os << "mma_matrix<";
374 auto shape = fragTy.getShape();
375 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
376 os << *dim << 'x';
377 os << shape.back() << 'x' << fragTy.getElementType();
378 os << ", \"" << fragTy.getOperand() << "\"" << '>';
379 })
380 .DefaultUnreachable("unexpected 'gpu' type kind");
381}
382
383static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
384 NamedAttribute attr) {
385 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
386 if (!array)
387 return op->emitOpError(Twine(attr.getName()) +
388 " must be a dense i32 array");
389 if (array.size() != 3)
390 return op->emitOpError(Twine(attr.getName()) +
391 " must contain exactly 3 elements");
392 return success();
393}
394
/// Verifies gpu dialect attributes attached to arbitrary operations. Known
/// launch-size attributes must be 3-element dense i32 arrays; the
/// `container_module` unit attribute must sit on a builtin.module and makes
/// every `gpu.launch_func` nested exactly one level below it subject to
/// kernel-reference checking.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  // Everything below only applies to the container-module unit attribute.
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    // The launch operand count and types must match the kernel signature.
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
490
491/// Parses an optional list of async operands with an optional leading keyword.
492/// (`async`)? (`[` ssa-id-list `]`)?
493///
494/// This method is used by the tablegen assembly format for async ops as well.
495static ParseResult parseAsyncDependencies(
496 OpAsmParser &parser, Type &asyncTokenType,
498 auto loc = parser.getCurrentLocation();
499 if (succeeded(parser.parseOptionalKeyword("async"))) {
500 if (parser.getNumResults() == 0)
501 return parser.emitError(loc, "needs to be named when marked 'async'");
502 asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
503 }
504 return parser.parseOperandList(asyncDependencies,
506}
507
508/// Prints optional async dependencies with its leading keyword.
509/// (`async`)? (`[` ssa-id-list `]`)?
510// Used by the tablegen assembly format for several async ops.
512 Type asyncTokenType,
513 OperandRange asyncDependencies) {
514 if (asyncTokenType)
515 printer << "async";
516 if (asyncDependencies.empty())
517 return;
518 if (asyncTokenType)
519 printer << ' ';
520 printer << llvm::interleaved_array(asyncDependencies);
521}
522
523// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
524/// Parses a GPU function memory attribution.
525///
526/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
527/// (`private` `(` ssa-id-and-type-list `)`)?
528///
529/// Note that this function parses only one of the two similar parts, with the
530/// keyword provided as argument.
531static ParseResult
532parseAttributions(OpAsmParser &parser, StringRef keyword,
534 // If we could not parse the keyword, just assume empty list and succeed.
535 if (failed(parser.parseOptionalKeyword(keyword)))
536 return success();
537
539 /*allowType=*/true);
540}
541
542static void printAttributions(OpAsmPrinter &p, StringRef keyword,
544 ArrayAttr attributes = {}) {
545 if (values.empty())
546 return;
547
548 p << ' ' << keyword << '(';
549 llvm::interleaveComma(
550 llvm::enumerate(values), p, [&p, attributes](auto pair) {
551 BlockArgument v = pair.value();
552 p << v << " : " << v.getType();
553
554 size_t attributionIndex = pair.index();
555 DictionaryAttr attrs;
556 if (attributes && attributionIndex < attributes.size())
557 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
558 if (attrs)
559 p.printOptionalAttrDict(attrs.getValue());
560 });
561 p << ')';
562}
563
564/// Verifies a GPU function memory attribution.
565static LogicalResult verifyAttributions(Operation *op,
566 ArrayRef<BlockArgument> attributions,
567 gpu::AddressSpace memorySpace) {
568 for (Value v : attributions) {
569 auto type = llvm::dyn_cast<MemRefType>(v.getType());
570 if (!type)
571 return op->emitOpError() << "expected memref type in attribution";
572
573 // We can only verify the address space if it hasn't already been lowered
574 // from the AddressSpaceAttr to a target-specific numeric value.
575 auto addressSpace =
576 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
577 if (!addressSpace)
578 continue;
579 if (addressSpace.getValue() != memorySpace)
580 return op->emitOpError()
581 << "expected memory space " << stringifyAddressSpace(memorySpace)
582 << " in attribution";
583 }
584 return success();
585}
586
587//===----------------------------------------------------------------------===//
588// AllReduceOp
589//===----------------------------------------------------------------------===//
590
591static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
592 Type resType) {
593 using Kind = gpu::AllReduceOperation;
594 if (llvm::is_contained(
595 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
596 opName)) {
597 if (!isa<FloatType>(resType))
598 return failure();
599 }
600
601 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
602 Kind::AND, Kind::OR, Kind::XOR},
603 opName)) {
604 if (!isa<IntegerType>(resType))
605 return failure();
606 }
607
608 return success();
609}
610
/// Verifies the body region of gpu.all_reduce: exactly one of {op attribute,
/// non-empty body} must be present. A body must take two arguments of the
/// result type and contain at least one gpu.yield of that type; an op
/// attribute must be compatible with the result type.
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    // Count yields across all blocks; at least one must be present and every
    // one must yield a single value of the result type.
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
644
646 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
647 if (!launchOp)
648 return false;
649
650 Region &body = launchOp.getBody();
651 assert(!body.empty() && "Invalid region");
652
653 // Only convert ops in gpu::launch entry block for now.
654 return op->getBlock() == &body.front();
655}
656
657OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
658 if (!getUniform() && canMakeGroupOpUniform(*this)) {
659 setUniform(true);
660 return getResult();
661 }
662
663 return nullptr;
664}
665
666// TODO: Support optional custom attributes (without dialect prefix).
667static ParseResult parseAllReduceOperation(AsmParser &parser,
668 AllReduceOperationAttr &attr) {
669 StringRef enumStr;
670 if (!parser.parseOptionalKeyword(&enumStr)) {
671 std::optional<AllReduceOperation> op =
672 gpu::symbolizeAllReduceOperation(enumStr);
673 if (!op)
674 return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
675 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
676 }
677 return success();
678}
679
681 AllReduceOperationAttr attr) {
682 if (attr)
683 attr.print(printer);
684}
685
686//===----------------------------------------------------------------------===//
687// SubgroupReduceOp
688//===----------------------------------------------------------------------===//
689
/// Verifies gpu.subgroup_reduce: the element type (scalar, or element of a
/// fixed vector) must be compatible with the reduction kind; cluster size,
/// if present, must be a power of two; cluster stride requires a cluster
/// size and must itself be a power of two.
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    // Scalable vectors are rejected; reduce over the element type instead.
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  // Stride defaults to 1; a non-default stride only makes sense with an
  // explicit cluster size.
  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
727
728OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
729 if (getClusterSize() == 1)
730 return getValue();
731
732 if (!getUniform() && canMakeGroupOpUniform(*this)) {
733 setUniform(true);
734 return getResult();
735 }
736
737 return nullptr;
738}
739
740//===----------------------------------------------------------------------===//
741// AsyncOpInterface
742//===----------------------------------------------------------------------===//
743
745 op->insertOperands(0, {token});
746 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
747 return;
748 auto attrName =
750 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
751
752 // Async dependencies is the only variadic operand.
753 if (!sizeAttr)
754 return;
755
756 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
757 ++sizes.front();
758 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
759}
760
761//===----------------------------------------------------------------------===//
762// LaunchOp
763//===----------------------------------------------------------------------===//
764
/// Builds a gpu.launch op: records async deps, grid/block (and optional
/// cluster) sizes and dynamic shared memory as operands, optional module /
/// function attributes, and creates the body region whose leading
/// kNumConfigRegionAttributes index arguments model ids/sizes, followed by
/// workgroup and private attributions.
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block argguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  // Cluster sizes are optional; either all three are present or none is
  // (enforced by LaunchOp::verify).
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute. Segments: async deps, grid xyz,
  // block xyz, cluster xyz (optional), dynamic shared memory (optional).
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
827
828KernelDim3 LaunchOp::getBlockIds() {
829 assert(!getBody().empty() && "LaunchOp body must not be empty.");
830 auto args = getBody().getArguments();
831 return KernelDim3{args[0], args[1], args[2]};
832}
833
834KernelDim3 LaunchOp::getThreadIds() {
835 assert(!getBody().empty() && "LaunchOp body must not be empty.");
836 auto args = getBody().getArguments();
837 return KernelDim3{args[3], args[4], args[5]};
838}
839
840KernelDim3 LaunchOp::getGridSize() {
841 assert(!getBody().empty() && "LaunchOp body must not be empty.");
842 auto args = getBody().getArguments();
843 return KernelDim3{args[6], args[7], args[8]};
844}
845
846KernelDim3 LaunchOp::getBlockSize() {
847 assert(!getBody().empty() && "LaunchOp body must not be empty.");
848 auto args = getBody().getArguments();
849 return KernelDim3{args[9], args[10], args[11]};
850}
851
852std::optional<KernelDim3> LaunchOp::getClusterIds() {
853 assert(!getBody().empty() && "LaunchOp body must not be empty.");
854 if (!hasClusterSize())
855 return std::nullopt;
856 auto args = getBody().getArguments();
857 return KernelDim3{args[12], args[13], args[14]};
858}
859
860std::optional<KernelDim3> LaunchOp::getClusterSize() {
861 assert(!getBody().empty() && "LaunchOp body must not be empty.");
862 if (!hasClusterSize())
863 return std::nullopt;
864 auto args = getBody().getArguments();
865 return KernelDim3{args[15], args[16], args[17]};
866}
867
868KernelDim3 LaunchOp::getGridSizeOperandValues() {
869 auto operands = getOperands().drop_front(getAsyncDependencies().size());
870 return KernelDim3{operands[0], operands[1], operands[2]};
871}
872
873KernelDim3 LaunchOp::getBlockSizeOperandValues() {
874 auto operands = getOperands().drop_front(getAsyncDependencies().size());
875 return KernelDim3{operands[3], operands[4], operands[5]};
876}
877
878std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
879 auto operands = getOperands().drop_front(getAsyncDependencies().size());
880 if (!hasClusterSize())
881 return std::nullopt;
882 return KernelDim3{operands[6], operands[7], operands[8]};
883}
884
885LogicalResult LaunchOp::verify() {
886 if (!(hasClusterSize()) &&
887 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
888 return emitOpError() << "cluster size must be all present";
889 return success();
890}
891
/// Verifies the gpu.launch body region: it must be non-empty, carry at least
/// the config region arguments plus the workgroup attributions, use the
/// correct address spaces for attributions, and terminate every
/// successor-less block with gpu.terminator. Also checks async-result naming.
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (getBody().empty()) {
    return emitOpError("body region is empty");
  }
  if (getBody().getNumArguments() <
      kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
    return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
933
// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
// NOTE(review): the first line of this signature is not visible in this
// listing; the visible parameters are the size/ID triples being printed.
                                KernelDim3 operands, KernelDim3 ids) {
  // Print the three identifiers, then each size bound to its operand.
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
945
/// Prints a gpu.launch: async token/dependencies, cluster/grid/block size
/// assignments, optional dynamic shared memory size, optional module/function
/// symbols, memory attributions, and the body region.
void LaunchOp::print(OpAsmPrinter &p) {
  // Print the optional async token and its dependency list.
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  // Entry block arguments were already printed by the size assignments.
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  // Elide attributes that the custom syntax above already encodes.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
995
// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
// NOTE(review): several signature/body lines of this function are not visible
// in this listing.
static ParseResult
                    StringRef keyword) {
  assert(indices.size() == 3 && "space for three indices expected");
                    /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();

  // Exactly one region argument per dimension (x, y, z) is expected.
  if (args.size() != 3) {
    return parser.emitError(parser.getNameLoc())
           << keyword << " expects 3 arguments, but got " << args.size();
  }
  std::move(args.begin(), args.end(), indices.begin());

  // Parse the comma-separated `%region_arg = %operand` assignments.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
1031
/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       (`dynamic_shared_memory_size` ssa-use)?
///       (`module(` symbol-ref-id `)`)?
///       (`function(` symbol-ref-id `)`)?
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  // A result is only legal in combination with the `async` keyword.
  if (parser.getNumResults() > 0) {
    if (!asyncTokenType)
      return parser.emitError(
          parser.getNameLoc(),
          "gpu.launch requires 'async' keyword to return a value");
    result.types.push_back(asyncTokenType);
  }

  // With clusters, three more size operands and six more region arguments
  // (cluster ids and cluster sizes) are parsed.
  bool hasCluster = false;
  if (succeeded(parser.parseOptionalKeyword(LaunchOp::getClustersKeyword()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Last three segment assigns the cluster size. In the region argument
  // list, this is last 6 arguments.
  // NOTE(review): the call line preceding these arguments is not visible in
  // this listing.
  if (hasCluster) {
        parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
        regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
                          LaunchOp::getBlocksKeyword()) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
                          LaunchOp::getThreadsKeyword()) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Parse the optional dynamic shared memory size, resolved as an i32 operand.
  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments, it has kNumConfigRegionAttributes arguments
  // that correspond to block/thread identifiers and grid/block sizes, all
  // having `index` type, a variadic number of WorkGroup Attributions and
  // a variadic number of Private Attributions. The number of WorkGroup
  // Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Record the operand segmentation; segments 7-9 are the optional cluster
  // sizes and the last segment the optional dynamic shared memory size.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
1193
1194/// Simplify the gpu.launch when the range of a thread or block ID is
1195/// trivially known to be one.
1196struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1197 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1198 LogicalResult matchAndRewrite(LaunchOp op,
1199 PatternRewriter &rewriter) const override {
1200 // If the range implies a single value for `id`, replace `id`'s uses by
1201 // zero.
1202 Value zero;
1203 bool simplified = false;
1204 auto constPropIdUses = [&](Value id, Value size) {
1205 // Check if size is trivially one.
1206 if (!matchPattern(size, m_One()))
1207 return;
1208 if (id.getUses().empty())
1209 return;
1210 if (!simplified) {
1211 // Create a zero value the first time.
1212 OpBuilder::InsertionGuard guard(rewriter);
1213 rewriter.setInsertionPointToStart(&op.getBody().front());
1214 zero =
1215 arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
1216 }
1217 rewriter.replaceAllUsesWith(id, zero);
1218 simplified = true;
1219 };
1220 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1221 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1222 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1223 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1224 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1225 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1226
1227 return success(simplified);
1228 }
1229};
1230
1231void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1232 MLIRContext *context) {
1233 rewrites.add<FoldLaunchArguments>(context);
1234}
1235
1236/// Adds a new block argument that corresponds to buffers located in
1237/// workgroup memory.
1238BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1239 auto attrName = getNumWorkgroupAttributionsAttrName();
1240 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1241 (*this)->setAttr(attrName,
1242 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1243 return getBody().insertArgument(
1244 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1245}
1246
1247/// Adds a new block argument that corresponds to buffers located in
1248/// private memory.
1249BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1250 // Buffers on the private memory always come after buffers on the workgroup
1251 // memory.
1252 return getBody().addArgument(type, loc);
1253}
1254
1255//===----------------------------------------------------------------------===//
1256// LaunchFuncOp
1257//===----------------------------------------------------------------------===//
1258
/// Builds a gpu.launch_func from an explicit nested kernel symbol reference
/// of the form @gpu_module::@kernel, optionally async.
/// NOTE(review): one operand line of this builder is not visible in this
/// listing.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  // Record the operand segmentation: the trailing segments are the optional
  // cluster sizes, the optional dynamic shared memory size, the kernel
  // operands, and the (empty here) async object.
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
1297
1298void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1299 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1300 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1301 ValueRange kernelOperands, Type asyncTokenType,
1302 ValueRange asyncDependencies,
1303 std::optional<KernelDim3> clusterSize) {
1304 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1305 auto kernelSymbol =
1306 SymbolRefAttr::get(kernelModule.getNameAttr(),
1307 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1308 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1309 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1310 asyncDependencies, clusterSize);
1311}
1312
/// Builds a gpu.launch_func synchronized through an explicit async object
/// operand instead of async token dependencies.
/// NOTE(review): one operand line of this builder is not visible in this
/// listing.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  // This form carries no async token dependencies.
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  // The async object occupies the final segment when present.
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
1345
1346StringAttr LaunchFuncOp::getKernelModuleName() {
1347 return getKernel().getRootReference();
1348}
1349
1350StringAttr LaunchFuncOp::getKernelName() {
1351 return getKernel().getLeafReference();
1352}
1353
1354unsigned LaunchFuncOp::getNumKernelOperands() {
1355 return getKernelOperands().size();
1356}
1357
1358Value LaunchFuncOp::getKernelOperand(unsigned i) {
1359 return getKernelOperands()[i];
1360}
1361
1362KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1363 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1364 return KernelDim3{operands[0], operands[1], operands[2]};
1365}
1366
1367KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1368 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1369 return KernelDim3{operands[3], operands[4], operands[5]};
1370}
1371
1372KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1373 assert(hasClusterSize() &&
1374 "cluster size is not set, check hasClusterSize() first");
1375 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1376 return KernelDim3{operands[6], operands[7], operands[8]};
1377}
1378
1379LogicalResult LaunchFuncOp::verify() {
1380 auto module = (*this)->getParentOfType<ModuleOp>();
1381 if (!module)
1382 return emitOpError("expected to belong to a module");
1383
1384 if (!module->getAttrOfType<UnitAttr>(
1385 GPUDialect::getContainerModuleAttrName()))
1386 return emitOpError("expected the closest surrounding module to have the '" +
1387 GPUDialect::getContainerModuleAttrName() +
1388 "' attribute");
1389
1390 if (hasClusterSize()) {
1391 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1392 getClusterSizeZ().getType() != getClusterSizeX().getType())
1393 return emitOpError()
1394 << "expects types of the cluster dimensions must be the same";
1395 }
1396
1397 return success();
1398}
1399
/// Parses the optional `: type` suffix of a launch dimension, defaulting to
/// `index` when absent; when a cluster operand is present, the three cluster
/// dimension types mirror the parsed type.
/// NOTE(review): the first line of this signature is not visible in this
/// listing.
static ParseResult
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    // No explicit type: launch dimensions default to `index`.
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}
1415
1416static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1417 Value clusterValue, Type clusterXTy,
1418 Type clusterYTy, Type clusterZTy) {
1419 if (!dimTy.isIndex())
1420 printer << ": " << dimTy;
1421}
1422
/// Parses the optional `args(%operand : type, ...)` clause of a launch_func.
/// NOTE(review): one parameter line and the closing parse call of this
/// function are not visible in this listing.
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &argTypes) {
  // Absence of the `args` keyword means an empty argument list.
  if (parser.parseOptionalKeyword("args"))
    return success();

  // Each list element has the form `%operand : type`.
  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

      parseElement, " in argument list");
}
1438
// Prints the `args(%operand : type, ...)` clause of a launch_func; nothing is
// printed for an empty operand list.
// NOTE(review): the first line of this signature is not visible in this
// listing.
    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  // Print each operand paired with its type.
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
1451
1452//===----------------------------------------------------------------------===//
1453// ShuffleOp
1454//===----------------------------------------------------------------------===//
1455
1456void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1457 int32_t offset, int32_t width, ShuffleMode mode) {
1458 build(builder, result, value,
1459 arith::ConstantOp::create(builder, result.location,
1460 builder.getI32IntegerAttr(offset)),
1461 arith::ConstantOp::create(builder, result.location,
1462 builder.getI32IntegerAttr(width)),
1463 mode);
1464}
1465
1466//===----------------------------------------------------------------------===//
1467// RotateOp
1468//===----------------------------------------------------------------------===//
1469
1470LogicalResult RotateOp::verify() {
1471 uint32_t offset = getOffset();
1472 uint32_t width = getWidth();
1473
1474 if (offset >= width) {
1475 return emitOpError() << "offset must be in the range [0, " << width << ")";
1476 }
1477
1478 return success();
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// BarrierOp
1483//===----------------------------------------------------------------------===//
1484
1485/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1486static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1487 PatternRewriter &rewriter) {
1488 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1489 if (!nextOp)
1490 return failure();
1491
1492 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1493 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1494
1495 if (thisMemfence) {
1496 rewriter.modifyOpInPlace(op, [&]() {
1497 if (!nextMemfence) {
1498 op.removeAddressSpacesAttr();
1499 return;
1500 }
1501 // Fast path - merge where the two barriers fence the same spaces.
1502 if (*thisMemfence == *nextMemfence) {
1503 return;
1504 }
1505
1506 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1507 for (Attribute attr : *thisMemfence)
1508 mergedSpaces.insert(attr);
1509 for (Attribute attr : *nextMemfence)
1510 mergedSpaces.insert(attr);
1511 op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
1512 });
1513 }
1514
1515 rewriter.eraseOp(nextOp);
1516 return success();
1517}
1518
1519void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1520 MLIRContext *context) {
1522}
1523
1524void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1525 mlir::OperationState &odsState,
1526 std::optional<AddressSpace> addressSpace) {
1527 ArrayAttr addressSpacesAttr;
1528 if (addressSpace)
1529 addressSpacesAttr = odsBuilder.getArrayAttr(
1530 AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
1531 build(odsBuilder, odsState, addressSpacesAttr);
1532}
1533
1534/// Builds a barrier that causes memory operations affecting `memrefToFence` to
1535/// be completed after the barrier is concluded. Currently, this means setting
1536/// the fenced address spaces to those of the given memref if it is a gpu
1537/// address space.
1538void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1539 Value memrefToFence) {
1540 std::optional<AddressSpace> addrSpaceToFence;
1541 if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
1542 if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1543 memrefType.getMemorySpace()))
1544 addrSpaceToFence = addrSpaceAttr.getValue();
1545 return build(builder, odsState, addrSpaceToFence);
1546}
1547
1548//===----------------------------------------------------------------------===//
1549// GPUFuncOp
1550//===----------------------------------------------------------------------===//
1551
1552/// Adds a new block argument that corresponds to buffers located in
1553/// workgroup memory.
1554BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1555 auto attrName = getNumWorkgroupAttributionsAttrName();
1556 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1557 (*this)->setAttr(attrName,
1558 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1559 return getBody().insertArgument(
1560 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1561}
1562
1563/// Adds a new block argument that corresponds to buffers located in
1564/// private memory.
1565BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1566 // Buffers on the private memory always come after buffers on the workgroup
1567 // memory.
1568 return getBody().addArgument(type, loc);
1569}
1570
/// Builds a gpu.func with the given name, signature, and memory attribution
/// types; creates an entry block whose arguments mirror the signature
/// followed by the attributions.
/// NOTE(review): one attribute-setting line of this builder is not visible in
/// this listing.
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
1596
/// Parses a GPU function memory attribution.
///
///    memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                           (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
/// NOTE(review): one parameter line and the argument-list parse call are not
/// visible in this listing.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  // Only materialize the per-attribution attribute array when at least one of
  // the newly parsed arguments carries attributes.
  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  // Normalize: arguments without attributes get an empty dictionary so the
  // resulting array lines up with the attribution list.
  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
1639
/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
/// NOTE(review): a few call lines of this parser are not visible in this
/// listing.
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
                           result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
      parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
      resultAttrs)))
    return failure();

  // Entry arguments must be named so the body region can reference them.
  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
1720
/// Prints a gpu.func: symbol name, signature, memory attributions, optional
/// `kernel` keyword, remaining attributes, and the body region.
/// NOTE(review): the attribute-printing call line is not visible in this
/// listing.
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  // Print memory attributions together with any per-attribution attributes.
  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  // Elide attributes that the custom syntax above already encodes.
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
1746
1747static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1748 StringAttr attrName) {
1749 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1750 if (!allAttrs || index >= allAttrs.size())
1751 return DictionaryAttr();
1752 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1753}
1754
1755DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1756 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1757}
1758
1759DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1760 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1761}
1762
1763static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1764 DictionaryAttr value, StringAttr attrName) {
1765 MLIRContext *ctx = op.getContext();
1766 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1767 SmallVector<Attribute> elements;
1768 if (allAttrs)
1769 elements.append(allAttrs.begin(), allAttrs.end());
1770 while (elements.size() <= index)
1771 elements.push_back(DictionaryAttr::get(ctx));
1772 if (!value)
1773 elements[index] = DictionaryAttr::get(ctx);
1774 else
1775 elements[index] = value;
1776 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1777 op->setAttr(attrName, newValue);
1778}
1779
/// Sets the attribute dictionary of the `index`-th workgroup attribution.
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}
1784
/// Sets the attribute dictionary of the `index`-th private attribution.
void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
1789
1790static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1791 StringAttr name, StringAttr attrsName) {
1792 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1793 if (!dict)
1794 return Attribute();
1795 return dict.get(name);
1796}
1797
/// Returns attribute `name` of the `index`-th workgroup attribution, or null
/// when absent. `index` must be a valid workgroup attribution index.
Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}
1805
/// Returns attribute `name` of the `index`-th private attribution, or null
/// when absent. `index` must be a valid private attribution index.
Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
1813
1814static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1815 Attribute value, StringAttr attrsName) {
1816 MLIRContext *ctx = op.getContext();
1818 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1819 if (oldDict)
1820 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1821
1822 bool found = false;
1823 bool mustSort = true;
1824 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1825 if (elems[i].getName() == name) {
1826 found = true;
1827 if (!value) {
1828 std::swap(elems[i], elems[elems.size() - 1]);
1829 elems.pop_back();
1830 } else {
1831 mustSort = false;
1832 elems[i] = NamedAttribute(elems[i].getName(), value);
1833 }
1834 break;
1835 }
1836 }
1837 if (!found) {
1838 if (!value)
1839 return;
1840 elems.emplace_back(name, value);
1841 }
1842 if (mustSort) {
1843 DictionaryAttr::sortInPlace(elems);
1844 }
1845 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1846 setAttributionAttrs(op, index, newDict, attrsName);
1847}
1848
/// Sets (or removes, when `value` is null) attribute `name` of the `index`-th
/// workgroup attribution. `index` must be a valid workgroup attribution index.
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}
1856
/// Sets (or removes, when `value` is null) attribute `name` of the `index`-th
/// private attribution. `index` must be a valid private attribution index.
void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
1864
1865LogicalResult GPUFuncOp::verifyType() {
1866 if (isKernel() && getFunctionType().getNumResults() != 0)
1867 return emitOpError() << "expected void return type for kernel function";
1868
1869 return success();
1870}
1871
1872/// Verifies the body of the function.
1873LogicalResult GPUFuncOp::verifyBody() {
1874 if (empty())
1875 return emitOpError() << "expected body with at least one block";
1876 unsigned numFuncArguments = getNumArguments();
1877 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1878 unsigned numBlockArguments = front().getNumArguments();
1879 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1880 return emitOpError() << "expected at least "
1881 << numFuncArguments + numWorkgroupAttributions
1882 << " arguments to body region";
1883
1884 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1885 for (unsigned i = 0; i < numFuncArguments; ++i) {
1886 Type blockArgType = front().getArgument(i).getType();
1887 if (funcArgTypes[i] != blockArgType)
1888 return emitOpError() << "expected body region argument #" << i
1889 << " to be of type " << funcArgTypes[i] << ", got "
1890 << blockArgType;
1891 }
1892
1893 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1894 GPUDialect::getWorkgroupAddressSpace())) ||
1895 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1896 GPUDialect::getPrivateAddressSpace())))
1897 return failure();
1898
1899 return success();
1900}
1901
1902//===----------------------------------------------------------------------===//
1903// ReturnOp
1904//===----------------------------------------------------------------------===//
1905
1906LogicalResult gpu::ReturnOp::verify() {
1907 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1908
1909 FunctionType funType = function.getFunctionType();
1910
1911 if (funType.getNumResults() != getOperands().size())
1912 return emitOpError()
1913 .append("expected ", funType.getNumResults(), " result operands")
1914 .attachNote(function.getLoc())
1915 .append("return type declared here");
1916
1917 for (const auto &pair : llvm::enumerate(
1918 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1919 auto [type, operand] = pair.value();
1920 if (type != operand.getType())
1921 return emitOpError() << "unexpected type `" << operand.getType()
1922 << "' for operand #" << pair.index();
1923 }
1924 return success();
1925}
1926
1927//===----------------------------------------------------------------------===//
1928// GPUModuleOp
1929//===----------------------------------------------------------------------===//
1930
1931void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1932 StringRef name, ArrayAttr targets,
1933 Attribute offloadingHandler) {
1934 result.addRegion()->emplaceBlock();
1935 Properties &props = result.getOrAddProperties<Properties>();
1936 if (targets)
1937 props.targets = targets;
1938 props.setSymName(builder.getStringAttr(name));
1939 props.offloadingHandler = offloadingHandler;
1940}
1941
1942void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1943 StringRef name, ArrayRef<Attribute> targets,
1944 Attribute offloadingHandler) {
1945 build(builder, result, name,
1946 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1947 offloadingHandler);
1948}
1949
1950bool GPUModuleOp::hasTarget(Attribute target) {
1951 if (ArrayAttr targets = getTargetsAttr())
1952 return llvm::count(targets.getValue(), target);
1953 return false;
1954}
1955
1956void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1957 ArrayAttr &targetsAttr = getProperties().targets;
1958 SmallVector<Attribute> targetsVector(targets);
1959 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1960}
1961
1962LogicalResult GPUModuleOp::verify() {
1963 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1964
1965 if (!targets)
1966 return success();
1967
1968 for (auto target : targets) {
1969 if (auto verifyTargetAttr =
1970 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1971 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1972 return failure();
1973 }
1974 }
1975 return success();
1976}
1977
1978//===----------------------------------------------------------------------===//
1979// GPUBinaryOp
1980//===----------------------------------------------------------------------===//
1981void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1982 Attribute offloadingHandler, ArrayAttr objects) {
1983 auto &properties = result.getOrAddProperties<Properties>();
1984 result.attributes.push_back(builder.getNamedAttr(
1986 properties.objects = objects;
1987 if (offloadingHandler)
1988 properties.offloadingHandler = offloadingHandler;
1989 else
1990 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1991}
1992
1993void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1994 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1995 build(builder, result, name, offloadingHandler,
1996 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1997}
1998
1999static ParseResult parseOffloadingHandler(OpAsmParser &parser,
2000 Attribute &offloadingHandler) {
2001 if (succeeded(parser.parseOptionalLess())) {
2002 if (parser.parseAttribute(offloadingHandler))
2003 return failure();
2004 if (parser.parseGreater())
2005 return failure();
2006 }
2007 if (!offloadingHandler)
2008 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
2009 return success();
2010}
2011
2013 Attribute offloadingHandler) {
2014 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
2015 printer << '<' << offloadingHandler << '>';
2016}
2017
2018//===----------------------------------------------------------------------===//
2019// GPUMemcpyOp
2020//===----------------------------------------------------------------------===//
2021
2022LogicalResult MemcpyOp::verify() {
2023 auto srcType = getSrc().getType();
2024 auto dstType = getDst().getType();
2025
2026 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
2027 return emitOpError("arguments have incompatible element type");
2028
2029 if (failed(verifyCompatibleShape(srcType, dstType)))
2030 return emitOpError("arguments have incompatible shape");
2031
2032 return success();
2033}
2034
2035namespace {
2036
2037/// Erases a common case of copy ops where a destination value is used only by
2038/// the copy op, alloc and dealloc ops.
2039struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
2040 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2041
2042 LogicalResult matchAndRewrite(MemcpyOp op,
2043 PatternRewriter &rewriter) const override {
2044 Value dest = op.getDst();
2045 Operation *destDefOp = dest.getDefiningOp();
2046 // `dest` must be defined by an op having Allocate memory effect in order to
2047 // perform the folding.
2048 if (!destDefOp ||
2050 return failure();
2051 // We can erase `op` iff `dest` has no other use apart from its
2052 // use by `op` and dealloc ops.
2053 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2054 return user != op &&
2055 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2056 }))
2057 return failure();
2058 // We can perform the folding if and only if op has a single async
2059 // dependency and produces an async token as result, or if it does not have
2060 // any async dependency and does not produce any async token result.
2061 if (op.getAsyncDependencies().size() > 1 ||
2062 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2063 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2064 return failure();
2065 rewriter.replaceOp(op, op.getAsyncDependencies());
2066 return success();
2067 }
2068};
2069
2070} // end anonymous namespace
2071
/// Registers gpu.memcpy canonicalization patterns (dead-copy elimination).
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
2076
2077//===----------------------------------------------------------------------===//
2078// GPU_SubgroupMmaLoadMatrixOp
2079//===----------------------------------------------------------------------===//
2080
2081LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2082 auto srcType = getSrcMemref().getType();
2083 auto resType = getRes().getType();
2084 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2085 auto operand = resMatrixType.getOperand();
2086 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2087
2088 if (!srcMemrefType.isLastDimUnitStride())
2089 return emitError(
2090 "expected source memref most minor dim must have unit stride");
2091
2092 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2093 return emitError("only AOp, BOp and COp can be loaded");
2094
2095 return success();
2096}
2097
2098//===----------------------------------------------------------------------===//
2099// GPU_SubgroupMmaStoreMatrixOp
2100//===----------------------------------------------------------------------===//
2101
2102LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2103 auto srcType = getSrc().getType();
2104 auto dstType = getDstMemref().getType();
2105 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2106 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2107
2108 if (!dstMemrefType.isLastDimUnitStride())
2109 return emitError(
2110 "expected destination memref most minor dim must have unit stride");
2111
2112 if (srcMatrixType.getOperand() != "COp")
2113 return emitError(
2114 "expected the operand matrix being stored to have 'COp' operand type");
2115
2116 return success();
2117}
2118
2119//===----------------------------------------------------------------------===//
2120// GPU_SubgroupMmaComputeOp
2121//===----------------------------------------------------------------------===//
2122
2123LogicalResult SubgroupMmaComputeOp::verify() {
2124 enum OperandMap { A, B, C };
2125 SmallVector<MMAMatrixType, 3> opTypes;
2126 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2127 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2128 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2129
2130 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2131 opTypes[C].getOperand() != "COp")
2132 return emitError("operands must be in the order AOp, BOp, COp");
2133
2134 ArrayRef<int64_t> aShape, bShape, cShape;
2135 aShape = opTypes[A].getShape();
2136 bShape = opTypes[B].getShape();
2137 cShape = opTypes[C].getShape();
2138
2139 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2140 bShape[1] != cShape[1])
2141 return emitError("operand shapes do not satisfy matmul constraints");
2142
2143 return success();
2144}
2145
/// Folds memref.cast operands of gpu.memcpy into the op itself.
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2150
/// Folds memref.cast operands of gpu.memset into the op itself.
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
2155
2156//===----------------------------------------------------------------------===//
2157// GPU_WaitOp
2158//===----------------------------------------------------------------------===//
2159
2160namespace {
2161
2162/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2163/// %t = gpu.wait async [] // No async dependencies.
2164/// ... gpu.wait ... [%t, ...] // %t can be removed.
2165struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2166public:
2168
2169 LogicalResult matchAndRewrite(WaitOp op,
2170 PatternRewriter &rewriter) const final {
2171 auto predicate = [](Value value) {
2172 auto waitOp = value.getDefiningOp<WaitOp>();
2173 return waitOp && waitOp->getNumOperands() == 0;
2174 };
2175 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2176 return failure();
2177 SmallVector<Value> validOperands;
2178 for (Value operand : op->getOperands()) {
2179 if (predicate(operand))
2180 continue;
2181 validOperands.push_back(operand);
2182 }
2183 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2184 return success();
2185 }
2186};
2187
2188/// Simplify trivial gpu.wait ops for the following patterns.
2189/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2190/// dependencies).
2191/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2192/// %t0.
2193/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2194/// dependencies nor return any token.
2195struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2196public:
2198
2199 LogicalResult matchAndRewrite(WaitOp op,
2200 PatternRewriter &rewriter) const final {
2201 // Erase gpu.wait ops that neither have any async dependencies nor return
2202 // any async token.
2203 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2204 rewriter.eraseOp(op);
2205 return success();
2206 }
2207 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2208 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2209 op.getAsyncToken()) {
2210 rewriter.replaceOp(op, op.getAsyncDependencies());
2211 return success();
2212 }
2213 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2214 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2215 rewriter.eraseOp(op);
2216 return success();
2217 }
2218 return failure();
2219 }
2220};
2221
2222} // end anonymous namespace
2223
/// Registers gpu.wait canonicalization patterns (redundant-token removal and
/// trivial-wait simplification).
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
2228
2229//===----------------------------------------------------------------------===//
2230// GPU_AllocOp
2231//===----------------------------------------------------------------------===//
2232
2233LogicalResult AllocOp::verify() {
2234 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2235
2236 if (failed(verifyDynamicDimensionCount(getOperation(), memRefType,
2237 getDynamicSizes())))
2238 return failure();
2239
2240 unsigned numSymbols = 0;
2241 if (!memRefType.getLayout().isIdentity())
2242 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2243 if (getSymbolOperands().size() != numSymbols) {
2244 return emitOpError(
2245 "symbol operand count does not equal memref symbol count");
2246 }
2247
2248 return success();
2249}
2250
2251namespace {
2252
2253/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2254/// `memref::AllocOp`.
2255struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2256 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2257
2258 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2259 PatternRewriter &rewriter) const override {
2260 std::optional<int64_t> index = dimOp.getConstantIndex();
2261 if (!index)
2262 return failure();
2263
2264 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2265 if (!memrefType || index.value() >= memrefType.getRank() ||
2266 !memrefType.isDynamicDim(index.value()))
2267 return failure();
2268
2269 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2270 if (!alloc)
2271 return failure();
2272
2273 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2274 memrefType.getDynamicDimIndex(index.value()));
2275 rewriter.replaceOp(dimOp, substituteOp);
2276 return success();
2277 }
2278};
2279
2280} // namespace
2281
/// Registers gpu.alloc canonicalization patterns (dim-of-alloc folding).
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
2286
2287//===----------------------------------------------------------------------===//
2288// GPU object attribute
2289//===----------------------------------------------------------------------===//
2290
2291LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2292 Attribute target, CompilationTarget format,
2293 StringAttr object, DictionaryAttr properties,
2294 KernelTableAttr kernels) {
2295 if (!target)
2296 return emitError() << "the target attribute cannot be null";
2297 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2298 return success();
2299 return emitError() << "the target attribute must implement or promise the "
2300 "`gpu::TargetAttrInterface`";
2301}
2302
2303namespace {
/// Parses the `[format =] "object"` portion of a #gpu.object attribute.
/// The format keyword is optional and defaults to `fatbin`.
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  // No keyword at all: use the default format.
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  // A keyword was parsed: it must name a known format and be followed by `=`.
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  // Keyword present but not a recognized format enum value.
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}
2328
2329void printObject(AsmPrinter &odsParser, CompilationTarget format,
2330 StringAttr object) {
2331 if (format != CompilationTarget::Fatbin)
2332 odsParser << stringifyEnum(format) << " = ";
2333 odsParser << object;
2334}
2335} // namespace
2336
2337//===----------------------------------------------------------------------===//
2338// GPU select object attribute
2339//===----------------------------------------------------------------------===//
2340
2341LogicalResult
2342gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2343 Attribute target) {
2344 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2345 if (target) {
2346 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2347 if (intAttr.getInt() < 0) {
2348 return emitError() << "the object index must be positive";
2349 }
2350 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2351 return emitError()
2352 << "the target attribute must be a GPU Target attribute";
2353 }
2354 }
2355 return success();
2356}
2357
2358//===----------------------------------------------------------------------===//
2359// DynamicSharedMemoryOp
2360//===----------------------------------------------------------------------===//
2361
2362LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2363 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2364 return emitOpError() << "must be inside an op with symbol table";
2365
2366 MemRefType memrefType = getResultMemref().getType();
2367 // Check address space
2368 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2369 return emitOpError() << "address space must be "
2370 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2371 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2372 }
2373 if (memrefType.hasStaticShape()) {
2374 return emitOpError() << "result memref type must be memref<?xi8, "
2375 "#gpu.address_space<workgroup>>";
2376 }
2377 return success();
2378}
2379
2380//===----------------------------------------------------------------------===//
2381// GPU WarpExecuteOnLane0Op
2382//===----------------------------------------------------------------------===//
2383
/// Prints a gpu.warp_execute_on_lane_0 op:
///   (%laneid)[warp_size] args(...) -> (results) { region } attrs
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  // The warp size is printed inline in `[...]`, so elide it from the trailing
  // attribute dictionary.
  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  // Terminators are elided when the op returns no results.
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
2401
/// Parses a gpu.warp_execute_on_lane_0 op; the inverse of print() above:
///   (%laneid)[warp_size] args(%v : type, ...) -> (types) { region } attrs
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  // Parse the mandatory `[warp_size]` and store it as an attribute.
  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  // The lane id is always an index operand.
  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  // Parse the optional `args(... : ...)` clause.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  // Add the implicit terminator if the region body omitted it.
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
2458
2459void WarpExecuteOnLane0Op::getSuccessorRegions(
2460 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2461 if (!point.isParent()) {
2462 regions.push_back(RegionSuccessor::parent());
2463 return;
2464 }
2465
2466 // The warp region is always executed
2467 regions.push_back(RegionSuccessor(&getWarpRegion()));
2468}
2469
/// Values forwarded to the successor: the op results when returning to the
/// parent, nothing when branching into the region.
ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
  return successor.isParent() ? ValueRange(getResults()) : ValueRange();
}
/// Convenience builder without forwarded arguments; delegates to the full
/// builder with empty operand/block-argument lists.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/{}, /*argTypes=*/{});
}
2479
2480void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2481 TypeRange resultTypes, Value laneId,
2482 int64_t warpSize, ValueRange args,
2483 TypeRange blockArgTypes) {
2484 result.addOperands(laneId);
2485 result.addAttribute(getAttributeNames()[0],
2486 builder.getI64IntegerAttr(warpSize));
2487 result.addTypes(resultTypes);
2488 result.addOperands(args);
2489 assert(args.size() == blockArgTypes.size());
2490 OpBuilder::InsertionGuard guard(builder);
2491 Region *warpRegion = result.addRegion();
2492 Block *block = builder.createBlock(warpRegion);
2493 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2494 block->addArgument(type, arg.getLoc());
2495}
2496
2497/// Helper check if the distributed vector type is consistent with the expanded
2498/// type and distributed size.
2499static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2500 int64_t warpSize, Operation *op) {
2501 // If the types matches there is no distribution.
2502 if (expanded == distributed)
2503 return success();
2504 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2505 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2506 if (!expandedVecType || !distributedVecType)
2507 return op->emitOpError("expected vector type for distributed operands.");
2508 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2509 expandedVecType.getElementType() != distributedVecType.getElementType())
2510 return op->emitOpError(
2511 "expected distributed vectors to have same rank and element type.");
2512
2513 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2514 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2515 int64_t eDim = expandedVecType.getDimSize(i);
2516 int64_t dDim = distributedVecType.getDimSize(i);
2517 if (eDim == dDim)
2518 continue;
2519 if (eDim % dDim != 0)
2520 return op->emitOpError()
2521 << "expected expanded vector dimension #" << i << " (" << eDim
2522 << ") to be a multipler of the distributed vector dimension ("
2523 << dDim << ")";
2524 scales[i] = eDim / dDim;
2525 }
2526 if (llvm::product_of(scales) != warpSize)
2527 return op->emitOpError()
2528 << "incompatible distribution dimensions from " << expandedVecType
2529 << " to " << distributedVecType << " with warp size = " << warpSize;
2530
2531 return success();
2532}
2533
2534LogicalResult WarpExecuteOnLane0Op::verify() {
2535 if (getArgs().size() != getWarpRegion().getNumArguments())
2536 return emitOpError(
2537 "expected same number op arguments and block arguments.");
2538 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2539 if (!yield)
2540 return emitOpError("expected body to be terminated with 'gpu.yield'");
2541 if (yield.getNumOperands() != getNumResults())
2542 return emitOpError(
2543 "expected same number of yield operands and return values.");
2544 int64_t warpSize = getWarpSize();
2545 for (auto [regionArg, arg] :
2546 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2547 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2548 warpSize, getOperation())))
2549 return failure();
2550 }
2551 for (auto [yieldOperand, result] :
2552 llvm::zip_equal(yield.getOperands(), getResults())) {
2553 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2554 warpSize, getOperation())))
2555 return failure();
2556 }
2557 return success();
2558}
/// Two types are compatible when one is a valid warp-size distribution of the
/// other (or they are equal).
bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
2563
/// Returns the gpu.yield terminating the body (guaranteed by the verifier).
gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
2567
2568//===----------------------------------------------------------------------===//
2569// GPU_SubgroupBroadcastOp
2570//===----------------------------------------------------------------------===//
2571
/// Integer range inference: broadcasting does not change values, so the
/// result range equals the source operand's range.
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}
2576
2577Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2578 switch (getBroadcastType()) {
2579 case BroadcastType::first_active_lane:
2580 // Cannot speculate first_lane broadcast, because speculating it across
2581 // control flow can change the active lanes.
2583 case BroadcastType::specific_lane:
2584 // Speculation should be safe as long as we inside structured control flow.
2586 }
2587 llvm_unreachable("Unknown BroadcastType");
2588}
2589
2590LogicalResult gpu::SubgroupBroadcastOp::verify() {
2591 switch (getBroadcastType()) {
2592 case BroadcastType::first_active_lane:
2593 if (getLane())
2594 return emitOpError()
2595 << "lane can only be specified for `specific_lane` broadcast";
2596 return success();
2597 case BroadcastType::specific_lane:
2598 if (!getLane())
2599 return emitOpError()
2600 << "lane must be specified for `specific_lane` broadcast";
2601 return success();
2602 }
2603 llvm_unreachable("Unknown BroadcastType");
2604}
2605
/// Folds a broadcast of a broadcast: the source is already subgroup-uniform,
/// so re-broadcasting it is a no-op and the prior result can be reused.
OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // Broadcast result is always uniform.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();

  return nullptr;
}
2613
2614//===----------------------------------------------------------------------===//
2615// GPU KernelMetadataAttr
2616//===----------------------------------------------------------------------===//
2617
2618KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2619 DictionaryAttr metadata) {
2620 assert(kernel && "invalid kernel");
2621 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2622 kernel.getAllArgAttrs(), metadata);
2623}
2624
2625KernelMetadataAttr
2626KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2627 FunctionOpInterface kernel,
2628 DictionaryAttr metadata) {
2629 assert(kernel && "invalid kernel");
2630 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2631 kernel.getAllArgAttrs(), metadata);
2632}
2633
2634KernelMetadataAttr
2635KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2636 if (attrs.empty())
2637 return *this;
2638 NamedAttrList attrList;
2639 if (DictionaryAttr dict = getMetadata())
2640 attrList.append(dict);
2641 attrList.append(attrs);
2642 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2643 attrList.getDictionary(getContext()));
2644}
2645
2646LogicalResult
2647KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2648 StringAttr name, Type functionType,
2649 ArrayAttr argAttrs, DictionaryAttr metadata) {
2650 if (name.empty())
2651 return emitError() << "the kernel name can't be empty";
2652 if (argAttrs) {
2653 if (llvm::any_of(argAttrs, [](Attribute attr) {
2654 return !llvm::isa<DictionaryAttr>(attr);
2655 }))
2656 return emitError()
2657 << "all attributes in the array must be a dictionary attribute";
2658 }
2659 return success();
2660}
2661
2662//===----------------------------------------------------------------------===//
2663// GPU KernelTableAttr
2664//===----------------------------------------------------------------------===//
2665
2666KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2667 ArrayRef<KernelMetadataAttr> kernels,
2668 bool isSorted) {
2669 // Note that `is_sorted` is always only invoked once even with assertions ON.
2670 assert((!isSorted || llvm::is_sorted(kernels)) &&
2671 "expected a sorted kernel array");
2672 // Immediately return the attribute if the array is sorted.
2673 if (isSorted || llvm::is_sorted(kernels))
2674 return Base::get(context, kernels);
2675 // Sort the array.
2676 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2677 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2678 return Base::get(context, kernelsTmp);
2679}
2680
2681KernelTableAttr KernelTableAttr::getChecked(
2682 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2683 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2684 // Note that `is_sorted` is always only invoked once even with assertions ON.
2685 assert((!isSorted || llvm::is_sorted(kernels)) &&
2686 "expected a sorted kernel array");
2687 // Immediately return the attribute if the array is sorted.
2688 if (isSorted || llvm::is_sorted(kernels))
2689 return Base::getChecked(emitError, context, kernels);
2690 // Sort the array.
2691 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2692 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2693 return Base::getChecked(emitError, context, kernelsTmp);
2694}
2695
2696LogicalResult
2697KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2698 ArrayRef<KernelMetadataAttr> kernels) {
2699 if (kernels.size() < 2)
2700 return success();
2701 // Check that the kernels are uniquely named.
2702 if (std::adjacent_find(kernels.begin(), kernels.end(),
2703 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2704 return l.getName() == r.getName();
2705 }) != kernels.end()) {
2706 return emitError() << "expected all kernels to be uniquely named";
2707 }
2708 return success();
2709}
2710
2711KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2712 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2713 return found ? *iterator : KernelMetadataAttr();
2714}
2715
2716KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2717 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2718 return found ? *iterator : KernelMetadataAttr();
2719}
2720
2721//===----------------------------------------------------------------------===//
2722// GPU target options
2723//===----------------------------------------------------------------------===//
2724
2739
2757
2758TypeID TargetOptions::getTypeID() const { return typeID; }
2759
2760StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2761
2765
2766StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2767
2768StringRef TargetOptions::getELFSection() const { return elfSection; }
2769
2773
2774function_ref<void(llvm::Module &)>
2778
2779function_ref<void(llvm::Module &)>
2783
2784function_ref<void(llvm::Module &)>
2788
2790 return isaCallback;
2791}
2792
2793CompilationTarget TargetOptions::getCompilationTarget() const {
2794 return compilationTarget;
2795}
2796
2798 return CompilationTarget::Fatbin;
2799}
2800
2801std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2803 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2804 llvm::StringSaver stringSaver(options.first);
2805 StringRef opts = cmdOptions;
2806 // For a correct tokenization of the command line options `opts` must be
2807 // unquoted, otherwise the tokenization function returns a single string: the
2808 // unquoted `cmdOptions` -which is not the desired behavior.
2809 // Remove any quotes if they are at the beginning and end of the string:
2810 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2811 opts.consume_front("\""), opts.consume_back("\"");
2812 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2813 opts.consume_front("'"), opts.consume_back("'");
2814#ifdef _WIN32
2815 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2816 /*MarkEOLs=*/false);
2817#else
2818 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2819 /*MarkEOLs=*/false);
2820#endif // _WIN32
2821 return options;
2822}
2823
2824std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2828
2829std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2831 size_t startPos = cmdOptions.find(startsWith);
2832 if (startPos == std::string::npos)
2833 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2834
2835 auto tokenized =
2836 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2837 cmdOptions.resize(startPos);
2838 return tokenized;
2839}
2840
2842
2843#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2844#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2845
2846#define GET_ATTRDEF_CLASSES
2847#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2848
2849#define GET_OP_CLASSES
2850#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2851
2852#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:579
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:611
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isIndex() const
Definition Types.cpp:56
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39