//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                   8       4       0
/// Example mask    : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// The logical id for e.g. physical id 5 (logical 2) is obtained by
/// constructing the filter ((1 << 5) - 1).
///
/// Example mask    : 0 0 0 1 1 0 1 0 0
/// Example filter  : 0 0 0 0 1 1 1 1 1
/// Intersection    : 0 0 0 0 1 0 1 0 0
/// PopCnt          : 2
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

///                   8       4       0
/// Example mask    : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// Whether e.g. physical id 5 (logical 2) is active is determined by
/// constructing the filter (1 << 5).
///
/// Example mask    : 0 0 0 1 1 0 1 0 0
/// Example filter  : 0 0 0 1 0 0 0 0 0
/// Intersection    : 0 0 0 1 0 0 0 0 0
/// Cmp             : 1
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

  return success();
}

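// For illustration, a type that satisfies these invariants prints as
// !gpu.mma_matrix<16x16xf16, "AOp">: a 2-D 16x16 fragment of f16 elements
// used as the "A" operand of a matrix multiply-accumulate.
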
//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print a refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

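// For example, an annotation accepted by this check is a 3-element dense i32
// array such as (illustrative):
//
//   gpu.known_block_size = array<i32: 128, 1, 1>
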
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

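// For an async op such as gpu.wait, this prints forms like (illustrative):
//
//   %token = gpu.wait async [%dep0, %dep1]
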
// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

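// The printed attribution form produced above looks like (illustrative):
//
//   workgroup(%buffer : memref<32xf32, 3>)
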
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

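// Two equivalent forms accepted by the checks above (illustrative):
//
//   %sum = gpu.all_reduce add %val {} : (f32) -> (f32)
//
//   %sum = gpu.all_reduce %val {
//   ^bb(%lhs : f32, %rhs : f32):
//     %s = arith.addf %lhs, %rhs : f32
//     gpu.yield %s : f32
//   } : (f32) -> (f32)
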
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

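// Forms accepted by the verifier above, including clustered reductions
// (illustrative):
//
//   %0 = gpu.subgroup_reduce add %a : (f32) -> f32
//   %1 = gpu.subgroup_reduce add %a cluster(size = 4, stride = 2) : (f32) -> f32
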
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

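// The region arguments created by build() are laid out as block ids, thread
// ids, grid sizes, then block sizes, all of `index` type:
//
//   (%bx, %by, %bz, %tx, %ty, %tz,
//    %grid_x, %grid_y, %grid_z, %block_x, %block_y, %block_z, ...)
//
// The accessors below index into that layout.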
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               (`dynamic_shared_memory_size` ssa-use)?
///               (`module(` symbol-ref-id `)`)?
///               (`function(` symbol-ref-id `)`)?
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
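///
/// For example (illustrative):
///
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
///              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c1, %sz = %c1) {
///     "some.op"(%bx, %tx) : (index, index) -> ()
///     gpu.terminator
///   }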
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size. In the region argument
  // list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

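// A launch_func satisfying these checks looks like (illustrative):
//
//   gpu.launch_func @kernels::@kernel_1
//       blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1)
//       args(%arg0 : f32, %arg1 : memref<?xf32, 1>)
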
static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

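// The built op corresponds to the assembly form (illustrative):
//
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32
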
//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase a gpu.barrier that is immediately followed by another gpu.barrier;
/// the threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

1539/// Parses a GPU function memory attribution.
1540///
1541/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1542/// (`private` `(` ssa-id-and-type-list `)`)?
1543///
1544/// Note that this function parses only one of the two similar parts, with the
1545/// keyword provided as argument.
1546static ParseResult
1547parseAttributions(OpAsmParser &parser, StringRef keyword,
1549 Attribute &attributionAttrs) {
1550 // If we could not parse the keyword, just assume empty list and succeed.
1551 if (failed(parser.parseOptionalKeyword(keyword)))
1552 return success();
1553
1554 size_t existingArgs = args.size();
1555 ParseResult result =
1557 /*allowType=*/true, /*allowAttrs=*/true);
1558 if (failed(result))
1559 return result;
1560
1561 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1562 [](const OpAsmParser::Argument &arg) -> bool {
1563 return arg.attrs && !arg.attrs.empty();
1564 });
1565 if (!hadAttrs) {
1566 attributionAttrs = nullptr;
1567 return result;
1568 }
1569
1570 Builder &builder = parser.getBuilder();
1571 SmallVector<Attribute> attributionAttrsVec;
1572 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1573 if (!argument.attrs)
1574 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1575 else
1576 attributionAttrsVec.push_back(argument.attrs);
1577 }
1578 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1579 return result;
1580}
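// For illustration (hypothetical IR, not taken from a test), the clauses this
// parser accepts look like:
//
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%tmp : memref<1xf32, #gpu.address_space<private>>)
//
// Per-argument attribute dictionaries are also allowed, e.g.
//   workgroup(%buf : memref<32xf32, 3> {foo.bar = 42 : i64})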
1581
1582/// Parses a GPU function.
1583///
1584/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1585/// (`->` function-result-list)? memory-attribution `kernel`?
1586/// function-attributes? region
1587ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1588 SmallVector<OpAsmParser::Argument> entryArgs;
1589 SmallVector<DictionaryAttr> resultAttrs;
1590 SmallVector<Type> resultTypes;
1591 bool isVariadic;
1592
1593 // Parse the function name.
1594 StringAttr nameAttr;
1595  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1596                             result.attributes))
1597 return failure();
1598
1599 auto signatureLocation = parser.getCurrentLocation();
1600  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1601          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1602 resultAttrs)))
1603 return failure();
1604
1605 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1606 return parser.emitError(signatureLocation)
1607 << "gpu.func requires named arguments";
1608
1609 // Construct the function type. More types will be added to the region, but
1610 // not to the function type.
1611 Builder &builder = parser.getBuilder();
1612
1613 SmallVector<Type> argTypes;
1614 for (auto &arg : entryArgs)
1615 argTypes.push_back(arg.type);
1616 auto type = builder.getFunctionType(argTypes, resultTypes);
1617 result.addAttribute(getFunctionTypeAttrName(result.name),
1618 TypeAttr::get(type));
1619
1620  call_interface_impl::addArgAndResultAttrs(
1621      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1622 getResAttrsAttrName(result.name));
1623
1624 Attribute workgroupAttributionAttrs;
1625 // Parse workgroup memory attributions.
1626 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1627 entryArgs, workgroupAttributionAttrs)))
1628 return failure();
1629
1630 // Store the number of operands we just parsed as the number of workgroup
1631 // memory attributions.
1632 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1633 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1634 builder.getI64IntegerAttr(numWorkgroupAttrs));
1635 if (workgroupAttributionAttrs)
1636 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1637 workgroupAttributionAttrs);
1638
1639 Attribute privateAttributionAttrs;
1640 // Parse private memory attributions.
1641 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1642 entryArgs, privateAttributionAttrs)))
1643 return failure();
1644 if (privateAttributionAttrs)
1645 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1646 privateAttributionAttrs);
1647
1648 // Parse the kernel attribute if present.
1649 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1650 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1651 builder.getUnitAttr());
1652
1653 // Parse attributes.
1654 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1655 return failure();
1656
1657 // Parse the region. If no argument names were provided, take all names
1658 // (including those of attributions) from the entry block.
1659 auto *body = result.addRegion();
1660 return parser.parseRegion(*body, entryArgs);
1661}
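// A hypothetical example of the full syntax accepted above (illustrative only,
// not from the upstream tests):
//
//   gpu.func @kernel(%arg0: index)
//       workgroup(%wg: memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%pv: memref<1xf32, #gpu.address_space<private>>)
//       kernel attributes {foo = "bar"} {
//     gpu.return
//   }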
1662
1663void GPUFuncOp::print(OpAsmPrinter &p) {
1664 p << ' ';
1665 p.printSymbolName(getName());
1666
1667 FunctionType type = getFunctionType();
1668 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1669 /*isVariadic=*/false,
1670 type.getResults());
1671
1672 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1673 getWorkgroupAttribAttrs().value_or(nullptr));
1674 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1675 getPrivateAttribAttrs().value_or(nullptr));
1676 if (isKernel())
1677 p << ' ' << getKernelKeyword();
1678
1679  function_interface_impl::printFunctionAttributes(
1680      p, *this,
1681 {getNumWorkgroupAttributionsAttrName(),
1682 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1683 getArgAttrsAttrName(), getResAttrsAttrName(),
1684 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1685 p << ' ';
1686 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1687}
1688
1689static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1690 StringAttr attrName) {
1691 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1692 if (!allAttrs || index >= allAttrs.size())
1693 return DictionaryAttr();
1694 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1695}
1696
1697DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1698 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1699}
1700
1701DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1702 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1703}
1704
1705static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1706 DictionaryAttr value, StringAttr attrName) {
1707 MLIRContext *ctx = op.getContext();
1708 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1709 SmallVector<Attribute> elements;
1710 if (allAttrs)
1711 elements.append(allAttrs.begin(), allAttrs.end());
1712 while (elements.size() <= index)
1713 elements.push_back(DictionaryAttr::get(ctx));
1714 if (!value)
1715 elements[index] = DictionaryAttr::get(ctx);
1716 else
1717 elements[index] = value;
1718 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1719 op->setAttr(attrName, newValue);
1720}
1721
1722void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1723 DictionaryAttr value) {
1724 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1725}
1726
1727void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1728 DictionaryAttr value) {
1729 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1730}
1731
1732static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1733 StringAttr name, StringAttr attrsName) {
1734 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1735 if (!dict)
1736 return Attribute();
1737 return dict.get(name);
1738}
1739
1740Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1741 StringAttr name) {
1742 assert(index < getNumWorkgroupAttributions() &&
1743 "index must map to a workgroup attribution");
1744 return getAttributionAttr(*this, index, name,
1745 getWorkgroupAttribAttrsAttrName());
1746}
1747
1748Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1749 StringAttr name) {
1750 assert(index < getNumPrivateAttributions() &&
1751 "index must map to a private attribution");
1752 return getAttributionAttr(*this, index, name,
1753 getPrivateAttribAttrsAttrName());
1754}
1755
1756static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1757 Attribute value, StringAttr attrsName) {
1758 MLIRContext *ctx = op.getContext();
1759  SmallVector<NamedAttribute> elems;
1760  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1761 if (oldDict)
1762 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1763
1764 bool found = false;
1765 bool mustSort = true;
1766 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1767 if (elems[i].getName() == name) {
1768 found = true;
1769 if (!value) {
1770 std::swap(elems[i], elems[elems.size() - 1]);
1771 elems.pop_back();
1772 } else {
1773 mustSort = false;
1774 elems[i] = NamedAttribute(elems[i].getName(), value);
1775 }
1776 break;
1777 }
1778 }
1779 if (!found) {
1780 if (!value)
1781 return;
1782 elems.emplace_back(name, value);
1783 }
1784 if (mustSort) {
1785 DictionaryAttr::sortInPlace(elems);
1786 }
1787 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1788 setAttributionAttrs(op, index, newDict, attrsName);
1789}
1790
1791void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1792 Attribute value) {
1793 assert(index < getNumWorkgroupAttributions() &&
1794 "index must map to a workgroup attribution");
1795 setAttributionAttr(*this, index, name, value,
1796 getWorkgroupAttribAttrsAttrName());
1797}
1798
1799void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1800 Attribute value) {
1801 assert(index < getNumPrivateAttributions() &&
1802 "index must map to a private attribution");
1803 setAttributionAttr(*this, index, name, value,
1804 getPrivateAttribAttrsAttrName());
1805}
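// Usage sketch (attribute name assumed for illustration): attach and query a
// dictionary entry on workgroup attribution #0 of a GPUFuncOp `func`.
//
//   func.setWorkgroupAttributionAttr(
//       0, builder.getStringAttr("my.align"), builder.getI64IntegerAttr(16));
//   Attribute align =
//       func.getWorkgroupAttributionAttr(0, builder.getStringAttr("my.align"));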
1806
1807LogicalResult GPUFuncOp::verifyType() {
1808 if (isKernel() && getFunctionType().getNumResults() != 0)
1809 return emitOpError() << "expected void return type for kernel function";
1810
1811 return success();
1812}
1813
1814/// Verifies the body of the function.
1815LogicalResult GPUFuncOp::verifyBody() {
1816 if (empty())
1817 return emitOpError() << "expected body with at least one block";
1818 unsigned numFuncArguments = getNumArguments();
1819 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1820 unsigned numBlockArguments = front().getNumArguments();
1821 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1822 return emitOpError() << "expected at least "
1823 << numFuncArguments + numWorkgroupAttributions
1824 << " arguments to body region";
1825
1826 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1827 for (unsigned i = 0; i < numFuncArguments; ++i) {
1828 Type blockArgType = front().getArgument(i).getType();
1829 if (funcArgTypes[i] != blockArgType)
1830 return emitOpError() << "expected body region argument #" << i
1831 << " to be of type " << funcArgTypes[i] << ", got "
1832 << blockArgType;
1833 }
1834
1835 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1836 GPUDialect::getWorkgroupAddressSpace())) ||
1837 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1838 GPUDialect::getPrivateAddressSpace())))
1839 return failure();
1840
1841 return success();
1842}
1843
1844//===----------------------------------------------------------------------===//
1845// ReturnOp
1846//===----------------------------------------------------------------------===//
1847
1848LogicalResult gpu::ReturnOp::verify() {
1849 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1850
1851 FunctionType funType = function.getFunctionType();
1852
1853 if (funType.getNumResults() != getOperands().size())
1854 return emitOpError()
1855 .append("expected ", funType.getNumResults(), " result operands")
1856 .attachNote(function.getLoc())
1857 .append("return type declared here");
1858
1859 for (const auto &pair : llvm::enumerate(
1860 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1861 auto [type, operand] = pair.value();
1862 if (type != operand.getType())
1863 return emitOpError() << "unexpected type `" << operand.getType()
1864 << "' for operand #" << pair.index();
1865 }
1866 return success();
1867}
1868
1869//===----------------------------------------------------------------------===//
1870// GPUModuleOp
1871//===----------------------------------------------------------------------===//
1872
1873void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1874 StringRef name, ArrayAttr targets,
1875 Attribute offloadingHandler) {
1876 result.addRegion()->emplaceBlock();
1877 Properties &props = result.getOrAddProperties<Properties>();
1878 if (targets)
1879 props.targets = targets;
1880 props.setSymName(builder.getStringAttr(name));
1881 props.offloadingHandler = offloadingHandler;
1882}
1883
1884void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1885 StringRef name, ArrayRef<Attribute> targets,
1886 Attribute offloadingHandler) {
1887 build(builder, result, name,
1888 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1889 offloadingHandler);
1890}
1891
1892bool GPUModuleOp::hasTarget(Attribute target) {
1893 if (ArrayAttr targets = getTargetsAttr())
1894 return llvm::count(targets.getValue(), target);
1895 return false;
1896}
1897
1898void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1899 ArrayAttr &targetsAttr = getProperties().targets;
1900 SmallVector<Attribute> targetsVector(targets);
1901 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1902}
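// For illustration (hypothetical IR): a gpu.module carrying the target list
// that hasTarget/setTargets operate on.
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//   }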
1903
1904LogicalResult GPUModuleOp::verify() {
1905 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1906
1907 if (!targets)
1908 return success();
1909
1910 for (auto target : targets) {
1911 if (auto verifyTargetAttr =
1912 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1913 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1914 return failure();
1915 }
1916 }
1917 return success();
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// GPUBinaryOp
1922//===----------------------------------------------------------------------===//
1923void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1924 Attribute offloadingHandler, ArrayAttr objects) {
1925 auto &properties = result.getOrAddProperties<Properties>();
1926 result.attributes.push_back(builder.getNamedAttr(
1927      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1928  properties.objects = objects;
1929 if (offloadingHandler)
1930 properties.offloadingHandler = offloadingHandler;
1931 else
1932 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1933}
1934
1935void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1936 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1937 build(builder, result, name, offloadingHandler,
1938 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1939}
1940
1941static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1942 Attribute &offloadingHandler) {
1943 if (succeeded(parser.parseOptionalLess())) {
1944 if (parser.parseAttribute(offloadingHandler))
1945 return failure();
1946 if (parser.parseGreater())
1947 return failure();
1948 }
1949 if (!offloadingHandler)
1950 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1951 return success();
1952}
1953
1954static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1955                                   Attribute offloadingHandler) {
1956 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1957 printer << '<' << offloadingHandler << '>';
1958}
1959
1960//===----------------------------------------------------------------------===//
1961// GPUMemcpyOp
1962//===----------------------------------------------------------------------===//
1963
1964LogicalResult MemcpyOp::verify() {
1965 auto srcType = getSrc().getType();
1966 auto dstType = getDst().getType();
1967
1968 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1969 return emitOpError("arguments have incompatible element type");
1970
1971 if (failed(verifyCompatibleShape(srcType, dstType)))
1972 return emitOpError("arguments have incompatible shape");
1973
1974 return success();
1975}
1976
1977namespace {
1978
1979/// Erases a common case of copy ops where a destination value is used only by
1980/// the copy op, alloc and dealloc ops.
1981struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1982 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1983
1984 LogicalResult matchAndRewrite(MemcpyOp op,
1985 PatternRewriter &rewriter) const override {
1986 Value dest = op.getDst();
1987 Operation *destDefOp = dest.getDefiningOp();
1988 // `dest` must be defined by an op having Allocate memory effect in order to
1989 // perform the folding.
1990 if (!destDefOp ||
1991        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1992      return failure();
1993 // We can erase `op` iff `dest` has no other use apart from its
1994 // use by `op` and dealloc ops.
1995 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1996 return user != op &&
1997 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1998 }))
1999 return failure();
2000 // We can perform the folding if and only if op has a single async
2001 // dependency and produces an async token as result, or if it does not have
2002 // any async dependency and does not produce any async token result.
2003 if (op.getAsyncDependencies().size() > 1 ||
2004 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2005 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2006 return failure();
2007 rewriter.replaceOp(op, op.getAsyncDependencies());
2008 return success();
2009 }
2010};
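// Hypothetical IR illustrating the pattern above: %dst is used only by the
// alloc/copy/dealloc triple, so the gpu.memcpy is dead and can be erased.
//
//   %dst = memref.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   memref.dealloc %dst : memref<16xf32>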
2011
2012} // end anonymous namespace
2013
2014void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2015 MLIRContext *context) {
2016 results.add<EraseTrivialCopyOp>(context);
2017}
2018
2019//===----------------------------------------------------------------------===//
2020// GPU_SubgroupMmaLoadMatrixOp
2021//===----------------------------------------------------------------------===//
2022
2023LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2024 auto srcType = getSrcMemref().getType();
2025 auto resType = getRes().getType();
2026 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2027 auto operand = resMatrixType.getOperand();
2028 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2029
2030 if (!srcMemrefType.isLastDimUnitStride())
2031    return emitError(
2032        "expected the most minor dimension of the source memref to have unit stride");
2033
2034 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2035 return emitError("only AOp, BOp and COp can be loaded");
2036
2037 return success();
2038}
2039
2040//===----------------------------------------------------------------------===//
2041// GPU_SubgroupMmaStoreMatrixOp
2042//===----------------------------------------------------------------------===//
2043
2044LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2045 auto srcType = getSrc().getType();
2046 auto dstType = getDstMemref().getType();
2047 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2048 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2049
2050 if (!dstMemrefType.isLastDimUnitStride())
2051    return emitError(
2052        "expected the most minor dimension of the destination memref to have unit stride");
2053
2054 if (srcMatrixType.getOperand() != "COp")
2055 return emitError(
2056 "expected the operand matrix being stored to have 'COp' operand type");
2057
2058 return success();
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// GPU_SubgroupMmaComputeOp
2063//===----------------------------------------------------------------------===//
2064
2065LogicalResult SubgroupMmaComputeOp::verify() {
2066 enum OperandMap { A, B, C };
2067 SmallVector<MMAMatrixType, 3> opTypes;
2068 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2069 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2070 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2071
2072 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2073 opTypes[C].getOperand() != "COp")
2074 return emitError("operands must be in the order AOp, BOp, COp");
2075
2076 ArrayRef<int64_t> aShape, bShape, cShape;
2077 aShape = opTypes[A].getShape();
2078 bShape = opTypes[B].getShape();
2079 cShape = opTypes[C].getShape();
2080
2081 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2082 bShape[1] != cShape[1])
2083 return emitError("operand shapes do not satisfy matmul constraints");
2084
2085 return success();
2086}
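// Worked example of the shape constraint above (illustrative): with
// A = !gpu.mma_matrix<16x16xf16, "AOp">, B = !gpu.mma_matrix<16x16xf16, "BOp">
// and C = !gpu.mma_matrix<16x16xf16, "COp">, aShape[1] == bShape[0] (k = 16),
// aShape[0] == cShape[0] (m = 16) and bShape[1] == cShape[1] (n = 16), so the
// op verifies; a 16x8 B operand would be rejected.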
2087
2088LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2089 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2090 return memref::foldMemRefCast(*this);
2091}
2092
2093LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2094 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2095 return memref::foldMemRefCast(*this);
2096}
2097
2098//===----------------------------------------------------------------------===//
2099// GPU_WaitOp
2100//===----------------------------------------------------------------------===//
2101
2102namespace {
2103
2104/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2105/// %t = gpu.wait async [] // No async dependencies.
2106/// ... gpu.wait ... [%t, ...] // %t can be removed.
2107struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2108public:
2109  using OpRewritePattern<WaitOp>::OpRewritePattern;
2110
2111 LogicalResult matchAndRewrite(WaitOp op,
2112 PatternRewriter &rewriter) const final {
2113 auto predicate = [](Value value) {
2114 auto waitOp = value.getDefiningOp<WaitOp>();
2115 return waitOp && waitOp->getNumOperands() == 0;
2116 };
2117 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2118 return failure();
2119 SmallVector<Value> validOperands;
2120 for (Value operand : op->getOperands()) {
2121 if (predicate(operand))
2122 continue;
2123 validOperands.push_back(operand);
2124 }
2125 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2126 return success();
2127 }
2128};
2129
2130/// Simplify trivial gpu.wait ops for the following patterns.
2131/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2132/// dependencies).
2133/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2134/// %t0.
2135/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2136/// dependencies nor return any token.
2137struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2138public:
2139  using OpRewritePattern<WaitOp>::OpRewritePattern;
2140
2141 LogicalResult matchAndRewrite(WaitOp op,
2142 PatternRewriter &rewriter) const final {
2143 // Erase gpu.wait ops that neither have any async dependencies nor return
2144 // any async token.
2145 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2146 rewriter.eraseOp(op);
2147 return success();
2148 }
2149 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2150 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2151 op.getAsyncToken()) {
2152 rewriter.replaceOp(op, op.getAsyncDependencies());
2153 return success();
2154 }
2155 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2156 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2157 rewriter.eraseOp(op);
2158 return success();
2159 }
2160 return failure();
2161 }
2162};
2163
2164} // end anonymous namespace
2165
2166void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2167 MLIRContext *context) {
2168 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2169}
2170
2171//===----------------------------------------------------------------------===//
2172// GPU_AllocOp
2173//===----------------------------------------------------------------------===//
2174
2175LogicalResult AllocOp::verify() {
2176 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2177
2178 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2179 return emitOpError("dimension operand count does not equal memref "
2180 "dynamic dimension count");
2181
2182 unsigned numSymbols = 0;
2183 if (!memRefType.getLayout().isIdentity())
2184 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2185 if (getSymbolOperands().size() != numSymbols) {
2186 return emitOpError(
2187 "symbol operand count does not equal memref symbol count");
2188 }
2189
2190 return success();
2191}
2192
2193namespace {
2194
2195/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2196/// `memref::AllocOp`.
2197struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2198 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2199
2200 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2201 PatternRewriter &rewriter) const override {
2202 std::optional<int64_t> index = dimOp.getConstantIndex();
2203 if (!index)
2204 return failure();
2205
2206 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2207 if (!memrefType || index.value() >= memrefType.getRank() ||
2208 !memrefType.isDynamicDim(index.value()))
2209 return failure();
2210
2211 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2212 if (!alloc)
2213 return failure();
2214
2215 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2216 memrefType.getDynamicDimIndex(index.value()));
2217 rewriter.replaceOp(dimOp, substituteOp);
2218 return success();
2219 }
2220};
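// Hypothetical IR for the fold above: the memref.dim of the single dynamic
// dimension resolves to the corresponding size operand of the gpu.alloc.
//
//   %m = gpu.alloc(%n) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>  // folds to %n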
2221
2222} // namespace
2223
2224void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2225 MLIRContext *context) {
2226 results.add<SimplifyDimOfAllocOp>(context);
2227}
2228
2229//===----------------------------------------------------------------------===//
2230// GPU object attribute
2231//===----------------------------------------------------------------------===//
2232
2233LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2234 Attribute target, CompilationTarget format,
2235 StringAttr object, DictionaryAttr properties,
2236 KernelTableAttr kernels) {
2237 if (!target)
2238 return emitError() << "the target attribute cannot be null";
2239 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2240 return success();
2241 return emitError() << "the target attribute must implement or promise the "
2242 "`gpu::TargetAttrInterface`";
2243}
2244
2245namespace {
2246ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2247 StringAttr &object) {
2248 std::optional<CompilationTarget> formatResult;
2249 StringRef enumKeyword;
2250 auto loc = odsParser.getCurrentLocation();
2251 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2252 formatResult = CompilationTarget::Fatbin;
2253 if (!formatResult &&
2254 (formatResult =
2255 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2256 odsParser.parseEqual())
2257 return odsParser.emitError(loc, "expected an equal sign");
2258 if (!formatResult)
2259 return odsParser.emitError(loc, "expected keyword for GPU object format");
2260 FailureOr<StringAttr> objectResult =
2261 FieldParser<StringAttr>::parse(odsParser);
2262 if (failed(objectResult))
2263 return odsParser.emitError(odsParser.getCurrentLocation(),
2264 "failed to parse GPU_ObjectAttr parameter "
2265 "'object' which is to be a `StringAttr`");
2266 format = *formatResult;
2267 object = *objectResult;
2268 return success();
2269}
2270
2271void printObject(AsmPrinter &odsParser, CompilationTarget format,
2272 StringAttr object) {
2273 if (format != CompilationTarget::Fatbin)
2274 odsParser << stringifyEnum(format) << " = ";
2275 odsParser << object;
2276}
2277} // namespace
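// Illustrative forms handled by the parse/print pair above, assuming an NVVM
// target attribute; the format keyword is elided for the default Fatbin
// format:
//
//   #gpu.object<#nvvm.target, "BLOB">
//   #gpu.object<#nvvm.target, bin = "BLOB">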
2278
2279//===----------------------------------------------------------------------===//
2280// GPU select object attribute
2281//===----------------------------------------------------------------------===//
2282
2283LogicalResult
2284gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2285 Attribute target) {
2286  // Check `target`; it can be null, an integer attr, or a GPU target attribute.
2287 if (target) {
2288 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2289 if (intAttr.getInt() < 0) {
2290        return emitError() << "the object index must be non-negative";
2291 }
2292 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2293 return emitError()
2294 << "the target attribute must be a GPU Target attribute";
2295 }
2296 }
2297 return success();
2298}
2299
2300//===----------------------------------------------------------------------===//
2301// DynamicSharedMemoryOp
2302//===----------------------------------------------------------------------===//
2303
2304LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2305 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2306 return emitOpError() << "must be inside an op with symbol table";
2307
2308 MemRefType memrefType = getResultMemref().getType();
2309 // Check address space
2310 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2311 return emitOpError() << "address space must be "
2312 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2313 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2314 }
2315 if (memrefType.hasStaticShape()) {
2316 return emitOpError() << "result memref type must be memref<?xi8, "
2317 "#gpu.address_space<workgroup>>";
2318 }
2319 return success();
2320}
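// For illustration, the only well-formed result shape:
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>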
2321
2322//===----------------------------------------------------------------------===//
2323// GPU WarpExecuteOnLane0Op
2324//===----------------------------------------------------------------------===//
2325
2326void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2327 p << "(" << getLaneid() << ")";
2328
2329 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2330 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2331 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2332
2333 if (!getArgs().empty())
2334 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2335 if (!getResults().empty())
2336 p << " -> (" << getResults().getTypes() << ')';
2337 p << " ";
2338 p.printRegion(getRegion(),
2339 /*printEntryBlockArgs=*/true,
2340 /*printBlockTerminators=*/!getResults().empty());
2341 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2342}
2343
2344ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2345 OperationState &result) {
2346 // Create the region.
2347 result.regions.reserve(1);
2348 Region *warpRegion = result.addRegion();
2349
2350 auto &builder = parser.getBuilder();
2351 OpAsmParser::UnresolvedOperand laneId;
2352
2353 // Parse predicate operand.
2354 if (parser.parseLParen() ||
2355 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2356 parser.parseRParen())
2357 return failure();
2358
2359 int64_t warpSize;
2360 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2361 parser.parseRSquare())
2362 return failure();
2363 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2364 builder.getContext())),
2365 builder.getI64IntegerAttr(warpSize));
2366
2367 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2368 return failure();
2369
2370 llvm::SMLoc inputsOperandsLoc;
2371 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2372 SmallVector<Type> inputTypes;
2373 if (succeeded(parser.parseOptionalKeyword("args"))) {
2374 if (parser.parseLParen())
2375 return failure();
2376
2377 inputsOperandsLoc = parser.getCurrentLocation();
2378 if (parser.parseOperandList(inputsOperands) ||
2379 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2380 return failure();
2381 }
2382 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2383 result.operands))
2384 return failure();
2385
2386 // Parse optional results type list.
2387 if (parser.parseOptionalArrowTypeList(result.types))
2388 return failure();
2389 // Parse the region.
2390 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2391 /*argTypes=*/{}))
2392 return failure();
2393 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2394
2395 // Parse the optional attribute list.
2396 if (parser.parseOptionalAttrDict(result.attributes))
2397 return failure();
2398 return success();
2399}
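// A hypothetical round-trip example of the syntax handled by this
// printer/parser pair, with warp size 32 (the expanded vector<128xf32> block
// argument distributes to the vector<4xf32> operand, the yielded
// vector<32xf32> to the vector<1xf32> result):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg: vector<128xf32>):
//     ...
//     gpu.yield %y : vector<32xf32>
//   }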
2400
2401void WarpExecuteOnLane0Op::getSuccessorRegions(
2402 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2403 if (!point.isParent()) {
2404 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2405 return;
2406 }
2407
2408 // The warp region is always executed
2409 regions.push_back(RegionSuccessor(&getWarpRegion()));
2410}
2411
2412void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2413 TypeRange resultTypes, Value laneId,
2414 int64_t warpSize) {
2415 build(builder, result, resultTypes, laneId, warpSize,
2416 /*operands=*/{}, /*argTypes=*/{});
2417}
2418
2419void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2420 TypeRange resultTypes, Value laneId,
2421 int64_t warpSize, ValueRange args,
2422 TypeRange blockArgTypes) {
2423 result.addOperands(laneId);
2424 result.addAttribute(getAttributeNames()[0],
2425 builder.getI64IntegerAttr(warpSize));
2426 result.addTypes(resultTypes);
2427 result.addOperands(args);
2428 assert(args.size() == blockArgTypes.size());
2429 OpBuilder::InsertionGuard guard(builder);
2430 Region *warpRegion = result.addRegion();
2431 Block *block = builder.createBlock(warpRegion);
2432 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2433 block->addArgument(type, arg.getLoc());
2434}
2435
2436/// Helper to check that the distributed vector type is consistent with the
2437/// expanded type and the distributed size.
2438static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2439 int64_t warpSize, Operation *op) {
2440  // If the types match, there is no distribution.
2441 if (expanded == distributed)
2442 return success();
2443 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2444 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2445 if (!expandedVecType || !distributedVecType)
2446 return op->emitOpError("expected vector type for distributed operands.");
2447 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2448 expandedVecType.getElementType() != distributedVecType.getElementType())
2449 return op->emitOpError(
2450 "expected distributed vectors to have same rank and element type.");
2451
2452 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2453 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2454 int64_t eDim = expandedVecType.getDimSize(i);
2455 int64_t dDim = distributedVecType.getDimSize(i);
2456 if (eDim == dDim)
2457 continue;
2458 if (eDim % dDim != 0)
2459 return op->emitOpError()
2460 << "expected expanded vector dimension #" << i << " (" << eDim
2461           << ") to be a multiplier of the distributed vector dimension ("
2462 << dDim << ")";
2463 scales[i] = eDim / dDim;
2464 }
2465 if (llvm::product_of(scales) != warpSize)
2466 return op->emitOpError()
2467 << "incompatible distribution dimensions from " << expandedVecType
2468 << " to " << distributedVecType << " with warp size = " << warpSize;
2469
2470 return success();
2471}
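// Worked example (illustrative): distributing vector<64x4xf32> to
// vector<2x4xf32> yields scales (32, 1), whose product matches a warp size of
// 32, so the pair verifies; vector<4x4xf32> would give a product of 16 and be
// rejected.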
2472
2473LogicalResult WarpExecuteOnLane0Op::verify() {
2474 if (getArgs().size() != getWarpRegion().getNumArguments())
2475 return emitOpError(
2476        "expected same number of op arguments and block arguments.");
2477 gpu::YieldOp yield = getTerminator();
2478 if (yield.getNumOperands() != getNumResults())
2479 return emitOpError(
2480 "expected same number of yield operands and return values.");
2481 int64_t warpSize = getWarpSize();
2482 for (auto [regionArg, arg] :
2483 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2484 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2485 warpSize, getOperation())))
2486 return failure();
2487 }
2488 for (auto [yieldOperand, result] :
2489 llvm::zip_equal(yield.getOperands(), getResults())) {
2490 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2491 warpSize, getOperation())))
2492 return failure();
2493 }
2494 return success();
2495}
2496bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2497 return succeeded(
2498 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2499}
2500
2501gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2502 return cast<gpu::YieldOp>(getBody()->getTerminator());
2503}
2504
2505//===----------------------------------------------------------------------===//
2506// GPU_SubgroupBroadcastOp
2507//===----------------------------------------------------------------------===//
2508
2509void gpu::SubgroupBroadcastOp::inferResultRanges(
2510 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2511 setResultRange(getResult(), argRanges.front());
2512}
2513
2514Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2515 switch (getBroadcastType()) {
2516 case BroadcastType::first_active_lane:
2517 // Cannot speculate first_lane broadcast, because speculating it across
2518 // control flow can change the active lanes.
2519    return Speculation::NotSpeculatable;
2520  case BroadcastType::specific_lane:
2521    // Speculation should be safe as long as we are inside structured control flow.
2522    return Speculation::Speculatable;
2523  }
2524}
2525
2526LogicalResult gpu::SubgroupBroadcastOp::verify() {
2527 switch (getBroadcastType()) {
2528 case BroadcastType::first_active_lane:
2529 if (getLane())
2530 return emitOpError()
2531 << "lane can only be specified for `specific_lane` broadcast";
2532 return success();
2533 case BroadcastType::specific_lane:
2534 if (!getLane())
2535 return emitOpError()
2536 << "lane must be specified for `specific_lane` broadcast";
2537 return success();
2538 }
2539}
2540
2541OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
2542 // Broadcast result is always uniform.
2543 if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2544 return prev.getResult();
2545
2546 return nullptr;
2547}
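// Illustrative fold (hypothetical syntax sketch): broadcasting a value that is
// itself the result of a broadcast is a no-op, so %b1 folds to %b0.
//
//   %b0 = gpu.subgroup_broadcast %v, first_active_lane : f32
//   %b1 = gpu.subgroup_broadcast %b0, first_active_lane : f32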
2548
2549//===----------------------------------------------------------------------===//
2550// GPU KernelMetadataAttr
2551//===----------------------------------------------------------------------===//
2552
2553KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2554 DictionaryAttr metadata) {
2555 assert(kernel && "invalid kernel");
2556 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2557 kernel.getAllArgAttrs(), metadata);
2558}
2559
2560KernelMetadataAttr
2561KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2562 FunctionOpInterface kernel,
2563 DictionaryAttr metadata) {
2564 assert(kernel && "invalid kernel");
2565 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2566 kernel.getAllArgAttrs(), metadata);
2567}
2568
2569KernelMetadataAttr
2570KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2571 if (attrs.empty())
2572 return *this;
2573 NamedAttrList attrList;
2574 if (DictionaryAttr dict = getMetadata())
2575 attrList.append(dict);
2576 attrList.append(attrs);
2577 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2578 attrList.getDictionary(getContext()));
2579}
2580
2581LogicalResult
2582KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2583 StringAttr name, Type functionType,
2584 ArrayAttr argAttrs, DictionaryAttr metadata) {
2585 if (name.empty())
2586 return emitError() << "the kernel name can't be empty";
2587 if (argAttrs) {
2588 if (llvm::any_of(argAttrs, [](Attribute attr) {
2589 return !llvm::isa<DictionaryAttr>(attr);
2590 }))
2591 return emitError()
2592          << "all attributes in the array must be dictionary attributes";
2593 }
2594 return success();
2595}
2596
2597//===----------------------------------------------------------------------===//
2598// GPU KernelTableAttr
2599//===----------------------------------------------------------------------===//
2600
2601KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2602 ArrayRef<KernelMetadataAttr> kernels,
2603 bool isSorted) {
2604  // Note that `is_sorted` is invoked exactly once, even with assertions on.
2605 assert((!isSorted || llvm::is_sorted(kernels)) &&
2606 "expected a sorted kernel array");
2607 // Immediately return the attribute if the array is sorted.
2608 if (isSorted || llvm::is_sorted(kernels))
2609 return Base::get(context, kernels);
2610 // Sort the array.
2611 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2612 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2613 return Base::get(context, kernelsTmp);
2614}
2615
2616KernelTableAttr KernelTableAttr::getChecked(
2617 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2618 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2619  // Note that `is_sorted` is invoked exactly once, even with assertions on.
2620 assert((!isSorted || llvm::is_sorted(kernels)) &&
2621 "expected a sorted kernel array");
2622 // Immediately return the attribute if the array is sorted.
2623 if (isSorted || llvm::is_sorted(kernels))
2624 return Base::getChecked(emitError, context, kernels);
2625 // Sort the array.
2626 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2627 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2628 return Base::getChecked(emitError, context, kernelsTmp);
2629}
2630
2631LogicalResult
2632KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2633 ArrayRef<KernelMetadataAttr> kernels) {
2634 if (kernels.size() < 2)
2635 return success();
2636 // Check that the kernels are uniquely named.
2637 if (std::adjacent_find(kernels.begin(), kernels.end(),
2638 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2639 return l.getName() == r.getName();
2640 }) != kernels.end()) {
2641 return emitError() << "expected all kernels to be uniquely named";
2642 }
2643 return success();
2644}
2645
2646KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2647 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2648 return found ? *iterator : KernelMetadataAttr();
2649}
2650
2651KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2652 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2653 return found ? *iterator : KernelMetadataAttr();
2654}
2655
2656//===----------------------------------------------------------------------===//
2657// GPU target options
2658//===----------------------------------------------------------------------===//
2659
2661TargetOptions::TargetOptions(
2662    StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2663    StringRef cmdOptions, StringRef elfSection,
2664    CompilationTarget compilationTarget,
2665    function_ref<SymbolTable *()> getSymbolTableCallback,
2666    function_ref<void(llvm::Module &)> initializedLlvmIRCallback,
2667    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2668    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2669    function_ref<void(StringRef)> isaCallback)
2670    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2671                    cmdOptions, elfSection, compilationTarget,
2672                    getSymbolTableCallback, initializedLlvmIRCallback,
2673                    linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback) {}
2674
2675TargetOptions::TargetOptions(
2676    TypeID typeID, StringRef toolkitPath,
2677    ArrayRef<Attribute> librariesToLink, StringRef cmdOptions,
2678    StringRef elfSection, CompilationTarget compilationTarget,
2679    function_ref<SymbolTable *()> getSymbolTableCallback,
2680    function_ref<void(llvm::Module &)> initializedLlvmIRCallback,
2681    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2682    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2683    function_ref<void(StringRef)> isaCallback)
2684    : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2685      cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2686      compilationTarget(compilationTarget),
2687      getSymbolTableCallback(getSymbolTableCallback),
2688      initializedLlvmIRCallback(initializedLlvmIRCallback),
2689      linkedLlvmIRCallback(linkedLlvmIRCallback),
2690      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2691      isaCallback(isaCallback), typeID(typeID) {}
2692
2693TypeID TargetOptions::getTypeID() const { return typeID; }
2694
2695StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2696
2697ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2698  return librariesToLink;
2699}
2700
2701StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2702
2703StringRef TargetOptions::getELFSection() const { return elfSection; }
2704
2705SymbolTable *TargetOptions::getSymbolTable() const {
2706  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2707}
2708
2709 function_ref<void(llvm::Module &)>
2710TargetOptions::getInitializedLlvmIRCallback() const {
2711  return initializedLlvmIRCallback;
2712}
2713
2714 function_ref<void(llvm::Module &)>
2715TargetOptions::getLinkedLlvmIRCallback() const {
2716  return linkedLlvmIRCallback;
2717}
2718
2719 function_ref<void(llvm::Module &)>
2720TargetOptions::getOptimizedLlvmIRCallback() const {
2721  return optimizedLlvmIRCallback;
2722}
2723
2724function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2725  return isaCallback;
2726}
2727
2728CompilationTarget TargetOptions::getCompilationTarget() const {
2729 return compilationTarget;
2730}
2731
2732CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2733  return CompilationTarget::Fatbin;
2734}
2735
2736std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2737TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2738  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2739 llvm::StringSaver stringSaver(options.first);
2740 StringRef opts = cmdOptions;
2741  // For a correct tokenization of the command line options, `opts` must be
2742  // unquoted; otherwise the tokenization function returns the whole quoted
2743  // string as a single token, which is not the desired behavior.
2744 // Remove any quotes if they are at the beginning and end of the string:
2745 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2746 opts.consume_front("\""), opts.consume_back("\"");
2747 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2748 opts.consume_front("'"), opts.consume_back("'");
2749#ifdef _WIN32
2750 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2751 /*MarkEOLs=*/false);
2752#else
2753 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2754 /*MarkEOLs=*/false);
2755#endif // _WIN32
2756 return options;
2757}
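// For illustration: an input of "\"-O3 -v\"" is first stripped of its outer
// quotes and then tokenized into {"-O3", "-v"}; without the unquoting step the
// whole string would come back as a single token.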
2758
2759std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2760TargetOptions::tokenizeCmdOptions() const {
2761  return tokenizeCmdOptions(cmdOptions);
2762}
2763
2764std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2765TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2766  size_t startPos = cmdOptions.find(startsWith);
2767 if (startPos == std::string::npos)
2768 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2769
2770 auto tokenized =
2771 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2772 cmdOptions.resize(startPos);
2773 return tokenized;
2774}
2775
2776MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2777
2778#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2779#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2780
2781#define GET_ATTRDEF_CLASSES
2782#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2783
2784#define GET_OP_CLASSES
2785#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2786
2787#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:104
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
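A verifier-style sketch combining these accessors; the attribute name used here is a hypothetical placeholder.

static LogicalResult verifyHasMarker(Operation *op) {
  auto name = StringAttr::get(op->getContext(), "example.marker");
  if (!op->getAttrOfType<UnitAttr>(name))
    return op->emitOpError("expected an 'example.marker' unit attribute");
  return success();
}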
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
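A minimal rewrite-pattern sketch exercising these rewriter methods; `MyOp` is a hypothetical single-result op that just forwards its operand.

struct ForwardMyOp : public OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getOperand().getType() != op.getType())
      return failure();
    // All rewrites funnel through the rewriter so the driver can track them.
    rewriter.replaceOp(op, op.getOperand());
    return success();
  }
};
// Registered via: patterns.add<ForwardMyOp>(context);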
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (of any width).
Definition Types.cpp:76
bool isIndex() const
Definition Types.cpp:54
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (of any width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (of any width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
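A sketch of dispatching on these predicates, e.g. when mapping MLIR types to an external runtime's type tags:

static StringRef classify(Type type) {
  if (type.isF16()) return "f16";
  if (type.isF32()) return "f32";
  if (type.isF64()) return "f64";
  if (type.isIndex()) return "index";
  if (type.isSignedInteger()) return "si";   // Explicitly signed (siN).
  if (type.isUnsignedInteger()) return "ui"; // Explicitly unsigned (uiN).
  if (type.isInteger()) return "i";          // Remaining case: signless iN.
  return "other";
}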
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
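A sketch of inspecting a value's def-use information, assuming `value` is any SSA Value:

if (Operation *def = value.getDefiningOp())
  llvm::outs() << "defined by " << def->getName() << " at " << def->getLoc()
               << "\n";
else
  llvm::outs() << "block argument of type " << value.getType() << "\n";
for (Operation *user : value.getUsers())
  llvm::outs() << "used by " << user->getName() << "\n";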
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
unsigned getNumDims() const
Get number of dims.
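A construction sketch for the 16x16xf16 "AOp" operand type used by subgroup MMA ops; `ctx` is an assumed MLIRContext pointer.

Type f16 = Float16Type::get(ctx);
assert(gpu::MMAMatrixType::isValidElementType(f16));
auto matTy = gpu::MMAMatrixType::get({16, 16}, f16, /*operand=*/"AOp");
// matTy.getNumDims() == 2; matTy.getShape() == {16, 16};
// matTy.getOperand() == "AOp".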
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substring of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
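A hedged configuration sketch; the toolkit path and "-v" flag are placeholder values, not recommendations.

gpu::TargetOptions options(/*toolkitPath=*/"/usr/local/cuda",
                           /*librariesToLink=*/{},
                           /*cmdOptions=*/"-v");
// Defaults to CompilationTarget::Fatbin unless overridden.
gpu::CompilationTarget fmt = options.getCompilationTarget();
// Tokenized options backed by the returned allocator.
auto [allocator, tokens] = options.tokenizeCmdOptions();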
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
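A hedged sketch of driving the signature parser from a custom function-like op parser; the helper is assumed to live in the function interface implementation namespace, and wiring the parsed arguments and results into the op state is elided.

SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<Type> resultTypes;
SmallVector<DictionaryAttr> resultAttrs;
bool isVariadic = false;
if (function_interface_impl::parseFunctionSignatureWithArguments(
        parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
        resultAttrs))
  return failure();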
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:45
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the integer type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
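A sketch of constant matching with these helpers: folding a multiply-by-one, assuming `op` has two operands.

if (matchPattern(op->getOperand(1), m_One()))
  return op->getOperand(0); // x * 1 == x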
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
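A hedged skeleton approximating that simplification: when a launch bound is the constant 1, the matching ID inside the body is always 0.

static LogicalResult replaceIdWithZero(LaunchOp op, Value id, Value bound,
                                       PatternRewriter &rewriter) {
  if (!matchPattern(bound, m_One()))
    return failure();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&op.getBody().front());
  Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
  rewriter.replaceAllUsesWith(id, zero);
  return success();
}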
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39