//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
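
// For illustration: with the MappingId enumeration, a grid mapping such as
// DimY keeps its raw id (1) as the relative index, while a linear mapping
// such as LinearDim2 (id LinearDim0 + 2) yields relative index 2. The same
// convention repeats for the warpgroup, warp, thread, and lane mappings below.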

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                   8       4       0
/// Example mask    : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
///
/// Example mask    : 0 0 0 1 1 0 1 0 0
/// Example filter  : 0 0 0 0 1 1 1 1 1
/// Intersection    : 0 0 0 0 1 0 1 0 0
/// PopCnt          : 2
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

///                   8       4       0
/// Example mask    : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
///
/// Example mask    : 0 0 0 1 1 0 1 0 0
/// Example filter  : 0 0 0 1 0 0 0 0 0
/// Intersection    : 0 0 0 1 0 0 0 0 0
/// Cmp             : 1
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
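
// Taken together (a usage sketch, not additional API from this file): given a
// physical linear id %p, `createIsActiveIdPredicate` builds an i1 guard that
// is true only for ids selected by the mask, while
// `createLogicalLinearMappingId` renumbers the surviving ids densely as
// 0, 1, 2, ... by popcounting the mask bits below %p.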

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
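
// For reference, a type satisfying these invariants prints (via the dialect
// parser/printer below) as, e.g., `!gpu.mma_matrix<16x16xf16, "AOp">`.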

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
267
268void GPUDialect::initialize() {
269 addTypes<AsyncTokenType>();
270 addTypes<MMAMatrixType>();
271 addTypes<SparseDnTensorHandleType>();
272 addTypes<SparseSpMatHandleType>();
273 addTypes<SparseSpGEMMOpHandleType>();
274 addOperations<
275#define GET_OP_LIST
276#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
277 >();
278 addAttributes<
279#define GET_ATTRDEF_LIST
280#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
281 >();
282 addInterfaces<GPUInlinerInterface>();
283 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
284 TerminatorOp>();
285 declarePromisedInterfaces<
286 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
287 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
288 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
289}
290
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here; note that it should correspond to the
// parser above.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
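
// The keyword forms handled above round-trip in IR as, for example,
// `!gpu.async.token`, `!gpu.sparse.spmat_handle`, and
// `!gpu.mma_matrix<16x16xf16, "AOp">`.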

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
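
// For example, a well-formed launch-size annotation is a three-element dense
// i32 array such as `gpu.known_block_size = array<i32: 128, 1, 1>`.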

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are more or less deeply nested than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
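
// A printed form produced by the helpers above looks like (sketch):
//   %t = gpu.wait async [%dep0, %dep1]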

// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
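
// A printed attribution looks like, for example (sketch):
//   workgroup(%arg3 : memref<32xf32, #gpu.address_space<workgroup>>)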

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

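// Both verifier branches accept, for instance (sketches in GPU dialect
// syntax):
//   %0 = gpu.all_reduce add %src uniform {} : (f32) -> f32
//   %1 = gpu.all_reduce %src {
//   ^bb(%lhs : f32, %rhs : f32):
//     %sum = arith.addf %lhs, %rhs : f32
//     "gpu.yield"(%sum) : (f32) -> ()
//   } : (f32) -> f32
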
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

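// For instance, a clustered reduction accepted by this verifier (sketch):
//   %sum = gpu.subgroup_reduce add %a cluster(size = 4) : (f32) -> f32
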
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies are the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       (`dynamic_shared_memory_size` ssa-use)?
///       (`module(` symbol-ref-id `)`)?
///       (`function(` symbol-ref-id `)`)?
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three operand segments assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, plus a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

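// For reference, the textual form accepted by the parser above (a sketch):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%bw = %c32, %bh = %c1, %bd = %c1) {
//     gpu.terminator
//   }
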
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

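// A launch accepted by this verifier looks like, for example (sketch):
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1)
//       args(%arg0 : f32)
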
static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *op,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

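// A shuffle built by the helper above round-trips as, e.g. (sketch):
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32
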
//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that is immediately followed by another gpu.barrier;
/// the threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

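// For instance, given back-to-back barriers
//   gpu.barrier
//   gpu.barrier
// the pattern erases the first op, since its successor already synchronizes.
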
//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

1580/// Parses a GPU function.
1581///
1582/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1583/// (`->` function-result-list)? memory-attribution `kernel`?
1584/// function-attributes? region
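///
/// A minimal example (illustrative, not from the upstream file):
///
///   gpu.func @add(%arg0: memref<?xf32>, %arg1: memref<?xf32>)
///       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
///       kernel {
///     ...
///     gpu.return
///   }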
1585ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1586 SmallVector<OpAsmParser::Argument> entryArgs;
1587 SmallVector<DictionaryAttr> resultAttrs;
1588 SmallVector<Type> resultTypes;
1589 bool isVariadic;
1590
1591 // Parse the function name.
1592 StringAttr nameAttr;
1593 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1594 result.attributes))
1595 return failure();
1596
1597 auto signatureLocation = parser.getCurrentLocation();
1598 if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1599 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1600 resultAttrs)))
1600 resultAttrs)))
1601 return failure();
1602
1603 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1604 return parser.emitError(signatureLocation)
1605 << "gpu.func requires named arguments";
1606
1607 // Construct the function type. More types will be added to the region, but
1608 // not to the function type.
1609 Builder &builder = parser.getBuilder();
1610
1611 SmallVector<Type> argTypes;
1612 for (auto &arg : entryArgs)
1613 argTypes.push_back(arg.type);
1614 auto type = builder.getFunctionType(argTypes, resultTypes);
1615 result.addAttribute(getFunctionTypeAttrName(result.name),
1616 TypeAttr::get(type));
1617
1618 call_interface_impl::addArgAndResultAttrs(
1619 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1620 getResAttrsAttrName(result.name));
1620 getResAttrsAttrName(result.name));
1621
1622 Attribute workgroupAttributionAttrs;
1623 // Parse workgroup memory attributions.
1624 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1625 entryArgs, workgroupAttributionAttrs)))
1626 return failure();
1627
1628 // Store the number of operands we just parsed as the number of workgroup
1629 // memory attributions.
1630 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1631 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1632 builder.getI64IntegerAttr(numWorkgroupAttrs));
1633 if (workgroupAttributionAttrs)
1634 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1635 workgroupAttributionAttrs);
1636
1637 Attribute privateAttributionAttrs;
1638 // Parse private memory attributions.
1639 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1640 entryArgs, privateAttributionAttrs)))
1641 return failure();
1642 if (privateAttributionAttrs)
1643 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1644 privateAttributionAttrs);
1645
1646 // Parse the kernel attribute if present.
1647 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1648 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1649 builder.getUnitAttr());
1650
1651 // Parse attributes.
1652 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1653 return failure();
1654
1655 // Parse the region. If no argument names were provided, take all names
1656 // (including those of attributions) from the entry block.
1657 auto *body = result.addRegion();
1658 return parser.parseRegion(*body, entryArgs);
1659}
1660
1661void GPUFuncOp::print(OpAsmPrinter &p) {
1662 p << ' ';
1663 p.printSymbolName(getName());
1664
1665 FunctionType type = getFunctionType();
1666 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1667 /*isVariadic=*/false,
1668 type.getResults());
1669
1670 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1671 getWorkgroupAttribAttrs().value_or(nullptr));
1672 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1673 getPrivateAttribAttrs().value_or(nullptr));
1674 if (isKernel())
1675 p << ' ' << getKernelKeyword();
1676
1677 function_interface_impl::printFunctionAttributes(
1678 p, *this,
1679 {getNumWorkgroupAttributionsAttrName(),
1680 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1681 getArgAttrsAttrName(), getResAttrsAttrName(),
1682 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1683 p << ' ';
1684 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1685}
1686
1687static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1688 StringAttr attrName) {
1689 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1690 if (!allAttrs || index >= allAttrs.size())
1691 return DictionaryAttr();
1692 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1693}
1694
1695DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1696 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1697}
1698
1699DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1700 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1701}
1702
1703static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1704 DictionaryAttr value, StringAttr attrName) {
1705 MLIRContext *ctx = op.getContext();
1706 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1707 SmallVector<Attribute> elements;
1708 if (allAttrs)
1709 elements.append(allAttrs.begin(), allAttrs.end());
1710 while (elements.size() <= index)
1711 elements.push_back(DictionaryAttr::get(ctx));
1712 if (!value)
1713 elements[index] = DictionaryAttr::get(ctx);
1714 else
1715 elements[index] = value;
1716 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1717 op->setAttr(attrName, newValue);
1718}
1719
1720void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1721 DictionaryAttr value) {
1722 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1723}
1724
1725void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1726 DictionaryAttr value) {
1727 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1728}
1729
1730static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1731 StringAttr name, StringAttr attrsName) {
1732 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1733 if (!dict)
1734 return Attribute();
1735 return dict.get(name);
1736}
1737
1738Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1739 StringAttr name) {
1740 assert(index < getNumWorkgroupAttributions() &&
1741 "index must map to a workgroup attribution");
1742 return getAttributionAttr(*this, index, name,
1743 getWorkgroupAttribAttrsAttrName());
1744}
1745
1746Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1747 StringAttr name) {
1748 assert(index < getNumPrivateAttributions() &&
1749 "index must map to a private attribution");
1750 return getAttributionAttr(*this, index, name,
1751 getPrivateAttribAttrsAttrName());
1752}
1753
1754static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1755 Attribute value, StringAttr attrsName) {
1756 MLIRContext *ctx = op.getContext();
1757 SmallVector<NamedAttribute> elems;
1758 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1759 if (oldDict)
1760 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1761
1762 bool found = false;
1763 bool mustSort = true;
1764 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1765 if (elems[i].getName() == name) {
1766 found = true;
1767 if (!value) {
1768 std::swap(elems[i], elems[elems.size() - 1]);
1769 elems.pop_back();
1770 } else {
1771 mustSort = false;
1772 elems[i] = NamedAttribute(elems[i].getName(), value);
1773 }
1774 break;
1775 }
1776 }
1777 if (!found) {
1778 if (!value)
1779 return;
1780 elems.emplace_back(name, value);
1781 }
1782 if (mustSort) {
1783 DictionaryAttr::sortInPlace(elems);
1784 }
1785 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1786 setAttributionAttrs(op, index, newDict, attrsName);
1787}
1788
1789void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1790 Attribute value) {
1791 assert(index < getNumWorkgroupAttributions() &&
1792 "index must map to a workgroup attribution");
1793 setAttributionAttr(*this, index, name, value,
1794 getWorkgroupAttribAttrsAttrName());
1795}
1796
1797void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1798 Attribute value) {
1799 assert(index < getNumPrivateAttributions() &&
1800 "index must map to a private attribution");
1801 setAttributionAttr(*this, index, name, value,
1802 getPrivateAttribAttrsAttrName());
1803}
1804
1805LogicalResult GPUFuncOp::verifyType() {
1806 if (isKernel() && getFunctionType().getNumResults() != 0)
1807 return emitOpError() << "expected void return type for kernel function";
1808
1809 return success();
1810}
1811
1812/// Verifies the body of the function.
1813LogicalResult GPUFuncOp::verifyBody() {
1814 if (empty())
1815 return emitOpError() << "expected body with at least one block";
1816 unsigned numFuncArguments = getNumArguments();
1817 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1818 unsigned numBlockArguments = front().getNumArguments();
1819 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1820 return emitOpError() << "expected at least "
1821 << numFuncArguments + numWorkgroupAttributions
1822 << " arguments to body region";
1823
1824 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1825 for (unsigned i = 0; i < numFuncArguments; ++i) {
1826 Type blockArgType = front().getArgument(i).getType();
1827 if (funcArgTypes[i] != blockArgType)
1828 return emitOpError() << "expected body region argument #" << i
1829 << " to be of type " << funcArgTypes[i] << ", got "
1830 << blockArgType;
1831 }
1832
1833 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1834 GPUDialect::getWorkgroupAddressSpace())) ||
1835 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1836 GPUDialect::getPrivateAddressSpace())))
1837 return failure();
1838
1839 return success();
1840}
1841
1842//===----------------------------------------------------------------------===//
1843// ReturnOp
1844//===----------------------------------------------------------------------===//
1845
1846LogicalResult gpu::ReturnOp::verify() {
1847 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1848
1849 FunctionType funType = function.getFunctionType();
1850
1851 if (funType.getNumResults() != getOperands().size())
1852 return emitOpError()
1853 .append("expected ", funType.getNumResults(), " result operands")
1854 .attachNote(function.getLoc())
1855 .append("return type declared here");
1856
1857 for (const auto &pair : llvm::enumerate(
1858 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1859 auto [type, operand] = pair.value();
1860 if (type != operand.getType())
1861 return emitOpError() << "unexpected type `" << operand.getType()
1862 << "' for operand #" << pair.index();
1863 }
1864 return success();
1865}
1866
1867//===----------------------------------------------------------------------===//
1868// GPUModuleOp
1869//===----------------------------------------------------------------------===//
1870
1871void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1872 StringRef name, ArrayAttr targets,
1873 Attribute offloadingHandler) {
1874 result.addRegion()->emplaceBlock();
1875 Properties &props = result.getOrAddProperties<Properties>();
1876 if (targets)
1877 props.targets = targets;
1878 props.setSymName(builder.getStringAttr(name));
1879 props.offloadingHandler = offloadingHandler;
1880}
1881
1882void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1883 StringRef name, ArrayRef<Attribute> targets,
1884 Attribute offloadingHandler) {
1885 build(builder, result, name,
1886 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1887 offloadingHandler);
1888}
1889
1890bool GPUModuleOp::hasTarget(Attribute target) {
1891 if (ArrayAttr targets = getTargetsAttr())
1892 return llvm::count(targets.getValue(), target);
1893 return false;
1894}
1895
1896void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1897 ArrayAttr &targetsAttr = getProperties().targets;
1898 SmallVector<Attribute> targetsVector(targets);
1899 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1900}
1901
1902LogicalResult GPUModuleOp::verify() {
1903 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1904
1905 if (!targets)
1906 return success();
1907
1908 for (auto target : targets) {
1909 if (auto verifyTargetAttr =
1910 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1911 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1912 return failure();
1913 }
1914 }
1915 return success();
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// GPUBinaryOp
1920//===----------------------------------------------------------------------===//
1921void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1922 Attribute offloadingHandler, ArrayAttr objects) {
1923 auto &properties = result.getOrAddProperties<Properties>();
1924 result.attributes.push_back(builder.getNamedAttr(
1925 SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1926 properties.objects = objects;
1927 if (offloadingHandler)
1928 properties.offloadingHandler = offloadingHandler;
1929 else
1930 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1931}
1932
1933void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1934 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1935 build(builder, result, name, offloadingHandler,
1936 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1937}
1938
1939static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1940 Attribute &offloadingHandler) {
1941 if (succeeded(parser.parseOptionalLess())) {
1942 if (parser.parseAttribute(offloadingHandler))
1943 return failure();
1944 if (parser.parseGreater())
1945 return failure();
1946 }
1947 if (!offloadingHandler)
1948 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1949 return success();
1950}
1951
1952static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1953 Attribute offloadingHandler) {
1954 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1955 printer << '<' << offloadingHandler << '>';
1956}
1957
1958//===----------------------------------------------------------------------===//
1959// GPUMemcpyOp
1960//===----------------------------------------------------------------------===//
1961
1962LogicalResult MemcpyOp::verify() {
1963 auto srcType = getSrc().getType();
1964 auto dstType = getDst().getType();
1965
1966 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1967 return emitOpError("arguments have incompatible element type");
1968
1969 if (failed(verifyCompatibleShape(srcType, dstType)))
1970 return emitOpError("arguments have incompatible shape");
1971
1972 return success();
1973}
1974
1975namespace {
1976
1977/// Erases a common case of copy ops where a destination value is used only by
1978/// the copy op, alloc and dealloc ops.
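///
/// For example (illustrative, not from the upstream file), the gpu.memcpy
/// below can be erased because %dst is only used by the copy and its
/// alloc/dealloc:
///
///   %dst = gpu.alloc() : memref<16xf32>
///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
///   gpu.dealloc %dst : memref<16xf32>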
1979struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1980 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1981
1982 LogicalResult matchAndRewrite(MemcpyOp op,
1983 PatternRewriter &rewriter) const override {
1984 Value dest = op.getDst();
1985 Operation *destDefOp = dest.getDefiningOp();
1986 // `dest` must be defined by an op having Allocate memory effect in order to
1987 // perform the folding.
1988 if (!destDefOp ||
1989 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1990 return failure();
1991 // We can erase `op` iff `dest` has no other use apart from its
1992 // use by `op` and dealloc ops.
1993 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1994 return user != op &&
1995 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1996 }))
1997 return failure();
1998 // We can perform the folding if and only if op has a single async
1999 // dependency and produces an async token as result, or if it does not have
2000 // any async dependency and does not produce any async token result.
2001 if (op.getAsyncDependencies().size() > 1 ||
2002 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2003 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2004 return failure();
2005 rewriter.replaceOp(op, op.getAsyncDependencies());
2006 return success();
2007 }
2008};
2009
2010} // end anonymous namespace
2011
2012void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2013 MLIRContext *context) {
2014 results.add<EraseTrivialCopyOp>(context);
2015}
2016
2017//===----------------------------------------------------------------------===//
2018// GPU_SubgroupMmaLoadMatrixOp
2019//===----------------------------------------------------------------------===//
2020
2021LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2022 auto srcType = getSrcMemref().getType();
2023 auto resType = getRes().getType();
2024 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2025 auto operand = resMatrixType.getOperand();
2026 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2027
2028 if (!srcMemrefType.isLastDimUnitStride())
2029 return emitError(
2030 "expected source memref's most minor dim to have unit stride");
2031
2032 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2033 return emitError("only AOp, BOp and COp can be loaded");
2034
2035 return success();
2036}
2037
2038//===----------------------------------------------------------------------===//
2039// GPU_SubgroupMmaStoreMatrixOp
2040//===----------------------------------------------------------------------===//
2041
2042LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2043 auto srcType = getSrc().getType();
2044 auto dstType = getDstMemref().getType();
2045 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2046 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2047
2048 if (!dstMemrefType.isLastDimUnitStride())
2049 return emitError(
2050 "expected destination memref's most minor dim to have unit stride");
2051
2052 if (srcMatrixType.getOperand() != "COp")
2053 return emitError(
2054 "expected the operand matrix being stored to have 'COp' operand type");
2055
2056 return success();
2057}
2058
2059//===----------------------------------------------------------------------===//
2060// GPU_SubgroupMmaComputeOp
2061//===----------------------------------------------------------------------===//
2062
2063LogicalResult SubgroupMmaComputeOp::verify() {
2064 enum OperandMap { A, B, C };
2065 SmallVector<MMAMatrixType, 3> opTypes;
2066 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2067 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2068 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2069
2070 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2071 opTypes[C].getOperand() != "COp")
2072 return emitError("operands must be in the order AOp, BOp, COp");
2073
2074 ArrayRef<int64_t> aShape, bShape, cShape;
2075 aShape = opTypes[A].getShape();
2076 bShape = opTypes[B].getShape();
2077 cShape = opTypes[C].getShape();
2078
2079 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2080 bShape[1] != cShape[1])
2081 return emitError("operand shapes do not satisfy matmul constraints");
2082
2083 return success();
2084}
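// Worked example (editor's sketch): with A : 16x8 ("AOp"), B : 8x32 ("BOp")
// and C : 16x32 ("COp"), aShape[1] == bShape[0] == 8, aShape[0] == cShape[0]
// == 16 and bShape[1] == cShape[1] == 32, so the matmul shape check passes.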
2085
2086LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2087 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2088 return memref::foldMemRefCast(*this);
2089}
2090
2091LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2092 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2093 return memref::foldMemRefCast(*this);
2094}
2095
2096//===----------------------------------------------------------------------===//
2097// GPU_WaitOp
2098//===----------------------------------------------------------------------===//
2099
2100namespace {
2101
2102/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2103/// %t = gpu.wait async [] // No async dependencies.
2104/// ... gpu.wait ... [%t, ...] // %t can be removed.
2105struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2106public:
2107 using OpRewritePattern<WaitOp>::OpRewritePattern;
2108
2109 LogicalResult matchAndRewrite(WaitOp op,
2110 PatternRewriter &rewriter) const final {
2111 auto predicate = [](Value value) {
2112 auto waitOp = value.getDefiningOp<WaitOp>();
2113 return waitOp && waitOp->getNumOperands() == 0;
2114 };
2115 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2116 return failure();
2117 SmallVector<Value> validOperands;
2118 for (Value operand : op->getOperands()) {
2119 if (predicate(operand))
2120 continue;
2121 validOperands.push_back(operand);
2122 }
2123 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2124 return success();
2125 }
2126};
2127
2128/// Simplify trivial gpu.wait ops for the following patterns.
2129/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2130/// dependencies).
2131/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2132/// %t0.
2133/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2134/// dependencies nor return any token.
2135struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2136public:
2137 using OpRewritePattern<WaitOp>::OpRewritePattern;
2138
2139 LogicalResult matchAndRewrite(WaitOp op,
2140 PatternRewriter &rewriter) const final {
2141 // Erase gpu.wait ops that neither have any async dependencies nor return
2142 // any async token.
2143 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2144 rewriter.eraseOp(op);
2145 return success();
2146 }
2147 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2148 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2149 op.getAsyncToken()) {
2150 rewriter.replaceOp(op, op.getAsyncDependencies());
2151 return success();
2152 }
2153 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2154 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2155 rewriter.eraseOp(op);
2156 return success();
2157 }
2158 return failure();
2159 }
2160};
2161
2162} // end anonymous namespace
2163
2164void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2165 MLIRContext *context) {
2166 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2167}
2168
2169//===----------------------------------------------------------------------===//
2170// GPU_AllocOp
2171//===----------------------------------------------------------------------===//
2172
2173LogicalResult AllocOp::verify() {
2174 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2175
2176 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2177 return emitOpError("dimension operand count does not equal memref "
2178 "dynamic dimension count");
2179
2180 unsigned numSymbols = 0;
2181 if (!memRefType.getLayout().isIdentity())
2182 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2183 if (getSymbolOperands().size() != numSymbols) {
2184 return emitOpError(
2185 "symbol operand count does not equal memref symbol count");
2186 }
2187
2188 return success();
2189}
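// For example (editor's sketch): allocating memref<?x?xf32> requires exactly
// two dynamic-size operands, as in %m = gpu.alloc(%w, %h) : memref<?x?xf32>;
// any other operand count is rejected by the verifier above.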
2190
2191namespace {
2192
2193/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2194/// `memref::AllocOp`.
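///
/// For example (illustrative, not from the upstream file):
///
///   %m = gpu.alloc(%size) : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>  // folds to %size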
2195struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2196 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2197
2198 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2199 PatternRewriter &rewriter) const override {
2200 std::optional<int64_t> index = dimOp.getConstantIndex();
2201 if (!index)
2202 return failure();
2203
2204 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2205 if (!memrefType || index.value() >= memrefType.getRank() ||
2206 !memrefType.isDynamicDim(index.value()))
2207 return failure();
2208
2209 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2210 if (!alloc)
2211 return failure();
2212
2213 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2214 memrefType.getDynamicDimIndex(index.value()));
2215 rewriter.replaceOp(dimOp, substituteOp);
2216 return success();
2217 }
2218};
2219
2220} // namespace
2221
2222void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2223 MLIRContext *context) {
2224 results.add<SimplifyDimOfAllocOp>(context);
2225}
2226
2227//===----------------------------------------------------------------------===//
2228// GPU object attribute
2229//===----------------------------------------------------------------------===//
2230
2231LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2232 Attribute target, CompilationTarget format,
2233 StringAttr object, DictionaryAttr properties,
2234 KernelTableAttr kernels) {
2235 if (!target)
2236 return emitError() << "the target attribute cannot be null";
2237 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2238 return success();
2239 return emitError() << "the target attribute must implement or promise the "
2240 "`gpu::TargetAttrInterface`";
2241}
2242
2243namespace {
2244ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2245 StringAttr &object) {
2246 std::optional<CompilationTarget> formatResult;
2247 StringRef enumKeyword;
2248 auto loc = odsParser.getCurrentLocation();
2249 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2250 formatResult = CompilationTarget::Fatbin;
2251 if (!formatResult &&
2252 (formatResult =
2253 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2254 odsParser.parseEqual())
2255 return odsParser.emitError(loc, "expected an equal sign");
2256 if (!formatResult)
2257 return odsParser.emitError(loc, "expected keyword for GPU object format");
2258 FailureOr<StringAttr> objectResult =
2259 FieldParser<StringAttr>::parse(odsParser);
2260 if (failed(objectResult))
2261 return odsParser.emitError(odsParser.getCurrentLocation(),
2262 "failed to parse GPU_ObjectAttr parameter "
2263 "'object' which is to be a `StringAttr`");
2264 format = *formatResult;
2265 object = *objectResult;
2266 return success();
2267}
2268
2269void printObject(AsmPrinter &odsParser, CompilationTarget format,
2270 StringAttr object) {
2271 if (format != CompilationTarget::Fatbin)
2272 odsParser << stringifyEnum(format) << " = ";
2273 odsParser << object;
2274}
2275} // namespace
2276
2277//===----------------------------------------------------------------------===//
2278// GPU select object attribute
2279//===----------------------------------------------------------------------===//
2280
2281LogicalResult
2282gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2283 Attribute target) {
2284 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2285 if (target) {
2286 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2287 if (intAttr.getInt() < 0) {
2288 return emitError() << "the object index must be non-negative";
2289 }
2290 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2291 return emitError()
2292 << "the target attribute must be a GPU Target attribute";
2293 }
2294 }
2295 return success();
2296}
2297
2298//===----------------------------------------------------------------------===//
2299// DynamicSharedMemoryOp
2300//===----------------------------------------------------------------------===//
2301
2302LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2303 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2304 return emitOpError() << "must be inside an op with symbol table";
2305
2306 MemRefType memrefType = getResultMemref().getType();
2307 // Check address space
2308 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2309 return emitOpError() << "address space must be "
2310 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2311 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2312 }
2313 if (memrefType.hasStaticShape()) {
2314 return emitOpError() << "result memref type must be memref<?xi8, "
2315 "#gpu.address_space<workgroup>>";
2316 }
2317 return success();
2318}
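// A well-formed use looks like (editor's sketch):
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>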
2319
2320//===----------------------------------------------------------------------===//
2321// GPU WarpExecuteOnLane0Op
2322//===----------------------------------------------------------------------===//
2323
2324void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2325 p << "(" << getLaneid() << ")";
2326
2327 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2328 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2329 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2330
2331 if (!getArgs().empty())
2332 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2333 if (!getResults().empty())
2334 p << " -> (" << getResults().getTypes() << ')';
2335 p << " ";
2336 p.printRegion(getRegion(),
2337 /*printEntryBlockArgs=*/true,
2338 /*printBlockTerminators=*/!getResults().empty());
2339 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2340}
2341
2342ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2343 OperationState &result) {
2344 // Create the region.
2345 result.regions.reserve(1);
2346 Region *warpRegion = result.addRegion();
2347
2348 auto &builder = parser.getBuilder();
2349 OpAsmParser::UnresolvedOperand laneId;
2350
2351 // Parse predicate operand.
2352 if (parser.parseLParen() ||
2353 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2354 parser.parseRParen())
2355 return failure();
2356
2357 int64_t warpSize;
2358 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2359 parser.parseRSquare())
2360 return failure();
2361 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2362 builder.getContext())),
2363 builder.getI64IntegerAttr(warpSize));
2364
2365 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2366 return failure();
2367
2368 llvm::SMLoc inputsOperandsLoc;
2369 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2370 SmallVector<Type> inputTypes;
2371 if (succeeded(parser.parseOptionalKeyword("args"))) {
2372 if (parser.parseLParen())
2373 return failure();
2374
2375 inputsOperandsLoc = parser.getCurrentLocation();
2376 if (parser.parseOperandList(inputsOperands) ||
2377 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2378 return failure();
2379 }
2380 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2381 result.operands))
2382 return failure();
2383
2384 // Parse optional results type list.
2385 if (parser.parseOptionalArrowTypeList(result.types))
2386 return failure();
2387 // Parse the region.
2388 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2389 /*argTypes=*/{}))
2390 return failure();
2391 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2392
2393 // Parse the optional attribute list.
2394 if (parser.parseOptionalAttrDict(result.attributes))
2395 return failure();
2396 return success();
2397}
2398
2399void WarpExecuteOnLane0Op::getSuccessorRegions(
2400 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2401 if (!point.isParent()) {
2402 regions.push_back(RegionSuccessor(getOperation(), getResults()));
2403 return;
2404 }
2405
2406 // The warp region is always executed.
2407 regions.push_back(RegionSuccessor(&getWarpRegion()));
2408}
2409
2410void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2411 TypeRange resultTypes, Value laneId,
2412 int64_t warpSize) {
2413 build(builder, result, resultTypes, laneId, warpSize,
2414 /*operands=*/{}, /*argTypes=*/{});
2415}
2416
2417void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2418 TypeRange resultTypes, Value laneId,
2419 int64_t warpSize, ValueRange args,
2420 TypeRange blockArgTypes) {
2421 result.addOperands(laneId);
2422 result.addAttribute(getAttributeNames()[0],
2423 builder.getI64IntegerAttr(warpSize));
2424 result.addTypes(resultTypes);
2425 result.addOperands(args);
2426 assert(args.size() == blockArgTypes.size());
2427 OpBuilder::InsertionGuard guard(builder);
2428 Region *warpRegion = result.addRegion();
2429 Block *block = builder.createBlock(warpRegion);
2430 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2431 block->addArgument(type, arg.getLoc());
2432}
2433
2434/// Helper to check if the distributed vector type is consistent with the expanded
2435/// type and distributed size.
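///
/// For example (illustrative): distributing vector<2x64xf32> to
/// vector<2x2xf32> with warpSize = 32 gives per-dimension scales [1, 32];
/// their product equals the warp size, so the pair verifies.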
2436static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2437 int64_t warpSize, Operation *op) {
2438 // If the types match, there is no distribution.
2439 if (expanded == distributed)
2440 return success();
2441 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2442 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2443 if (!expandedVecType || !distributedVecType)
2444 return op->emitOpError("expected vector type for distributed operands.");
2445 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2446 expandedVecType.getElementType() != distributedVecType.getElementType())
2447 return op->emitOpError(
2448 "expected distributed vectors to have same rank and element type.");
2449
2450 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2451 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2452 int64_t eDim = expandedVecType.getDimSize(i);
2453 int64_t dDim = distributedVecType.getDimSize(i);
2454 if (eDim == dDim)
2455 continue;
2456 if (eDim % dDim != 0)
2457 return op->emitOpError()
2458 << "expected expanded vector dimension #" << i << " (" << eDim
2459 << ") to be a multiple of the distributed vector dimension ("
2460 << dDim << ")";
2461 scales[i] = eDim / dDim;
2462 }
2463 if (llvm::product_of(scales) != warpSize)
2464 return op->emitOpError()
2465 << "incompatible distribution dimensions from " << expandedVecType
2466 << " to " << distributedVecType << " with warp size = " << warpSize;
2467
2468 return success();
2469}
2470
2471LogicalResult WarpExecuteOnLane0Op::verify() {
2472 if (getArgs().size() != getWarpRegion().getNumArguments())
2473 return emitOpError(
2474 "expected same number of op arguments and block arguments.");
2475 gpu::YieldOp yield = getTerminator();
2476 if (yield.getNumOperands() != getNumResults())
2477 return emitOpError(
2478 "expected same number of yield operands and return values.");
2479 int64_t warpSize = getWarpSize();
2480 for (auto [regionArg, arg] :
2481 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2482 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2483 warpSize, getOperation())))
2484 return failure();
2485 }
2486 for (auto [yieldOperand, result] :
2487 llvm::zip_equal(yield.getOperands(), getResults())) {
2488 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2489 warpSize, getOperation())))
2490 return failure();
2491 }
2492 return success();
2493}
2494bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2495 return succeeded(
2496 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2497}
2498
2499gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2500 return cast<gpu::YieldOp>(getBody()->getTerminator());
2501}
2502
2503//===----------------------------------------------------------------------===//
2504// GPU_SubgroupBroadcastOp
2505//===----------------------------------------------------------------------===//
2506
2507void gpu::SubgroupBroadcastOp::inferResultRanges(
2508 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2509 setResultRange(getResult(), argRanges.front());
2510}
2511
2512Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2513 switch (getBroadcastType()) {
2514 case BroadcastType::first_active_lane:
2515 // Cannot speculate first_lane broadcast, because speculating it across
2516 // control flow can change the active lanes.
2517 return Speculation::NotSpeculatable;
2518 case BroadcastType::specific_lane:
2519 // Speculation should be safe as long as we are inside structured control flow.
2520 return Speculation::Speculatable;
2521 }
2522}
2523
2524LogicalResult gpu::SubgroupBroadcastOp::verify() {
2525 switch (getBroadcastType()) {
2526 case BroadcastType::first_active_lane:
2527 if (getLane())
2528 return emitOpError()
2529 << "lane can only be specified for `specific_lane` broadcast";
2530 return success();
2531 case BroadcastType::specific_lane:
2532 if (!getLane())
2533 return emitOpError()
2534 << "lane must be specified for `specific_lane` broadcast";
2535 return success();
2536 }
2537}
2538
2539//===----------------------------------------------------------------------===//
2540// GPU KernelMetadataAttr
2541//===----------------------------------------------------------------------===//
2542
2543KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2544 DictionaryAttr metadata) {
2545 assert(kernel && "invalid kernel");
2546 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2547 kernel.getAllArgAttrs(), metadata);
2548}
2549
2550KernelMetadataAttr
2551KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2552 FunctionOpInterface kernel,
2553 DictionaryAttr metadata) {
2554 assert(kernel && "invalid kernel");
2555 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2556 kernel.getAllArgAttrs(), metadata);
2557}
2558
2559KernelMetadataAttr
2560KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2561 if (attrs.empty())
2562 return *this;
2563 NamedAttrList attrList;
2564 if (DictionaryAttr dict = getMetadata())
2565 attrList.append(dict);
2566 attrList.append(attrs);
2567 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2568 attrList.getDictionary(getContext()));
2569}
2570
2571LogicalResult
2572KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2573 StringAttr name, Type functionType,
2574 ArrayAttr argAttrs, DictionaryAttr metadata) {
2575 if (name.empty())
2576 return emitError() << "the kernel name can't be empty";
2577 if (argAttrs) {
2578 if (llvm::any_of(argAttrs, [](Attribute attr) {
2579 return !llvm::isa<DictionaryAttr>(attr);
2580 }))
2581 return emitError()
2582 << "all attributes in the array must be dictionary attributes";
2583 }
2584 return success();
2585}
2586
2587//===----------------------------------------------------------------------===//
2588// GPU KernelTableAttr
2589//===----------------------------------------------------------------------===//
2590
2591KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2592 ArrayRef<KernelMetadataAttr> kernels,
2593 bool isSorted) {
2594 // Note that `is_sorted` is invoked only once, even with assertions ON.
2595 assert((!isSorted || llvm::is_sorted(kernels)) &&
2596 "expected a sorted kernel array");
2597 // Immediately return the attribute if the array is sorted.
2598 if (isSorted || llvm::is_sorted(kernels))
2599 return Base::get(context, kernels);
2600 // Sort the array.
2601 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2602 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2603 return Base::get(context, kernelsTmp);
2604}
2605
2606KernelTableAttr KernelTableAttr::getChecked(
2607 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2608 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2609 // Note that `is_sorted` is invoked only once, even with assertions ON.
2610 assert((!isSorted || llvm::is_sorted(kernels)) &&
2611 "expected a sorted kernel array");
2612 // Immediately return the attribute if the array is sorted.
2613 if (isSorted || llvm::is_sorted(kernels))
2614 return Base::getChecked(emitError, context, kernels);
2615 // Sort the array.
2616 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2617 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2618 return Base::getChecked(emitError, context, kernelsTmp);
2619}
2620
2621LogicalResult
2622KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2623 ArrayRef<KernelMetadataAttr> kernels) {
2624 if (kernels.size() < 2)
2625 return success();
2626 // Check that the kernels are uniquely named.
2627 if (std::adjacent_find(kernels.begin(), kernels.end(),
2628 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2629 return l.getName() == r.getName();
2630 }) != kernels.end()) {
2631 return emitError() << "expected all kernels to be uniquely named";
2632 }
2633 return success();
2634}
2635
2636KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2637 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2638 return found ? *iterator : KernelMetadataAttr();
2639}
2640
2641KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2642 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2643 return found ? *iterator : KernelMetadataAttr();
2644}
2645
2646//===----------------------------------------------------------------------===//
2647// GPU target options
2648//===----------------------------------------------------------------------===//
2649
2664
2682
2683TypeID TargetOptions::getTypeID() const { return typeID; }
2684
2685StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2686
2690
2691StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2692
2693StringRef TargetOptions::getELFSection() const { return elfSection; }
2694
2698
2699function_ref<void(llvm::Module &)>
2700TargetOptions::getInitialLlvmIRCallback() const {
2701 return initialLlvmIRCallback;
2702}
2703
2704function_ref<void(llvm::Module &)>
2705TargetOptions::getLinkedLlvmIRCallback() const {
2706 return linkedLlvmIRCallback;
2707}
2708
2709function_ref<void(llvm::Module &)>
2710TargetOptions::getOptimizedLlvmIRCallback() const {
2711 return optimizedLlvmIRCallback;
2712}
2713
2714function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2715 return isaCallback;
2716}
2717
2718CompilationTarget TargetOptions::getCompilationTarget() const {
2719 return compilationTarget;
2720}
2721
2722CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2723 return CompilationTarget::Fatbin;
2724}
2725
2726std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2727TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2728 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2729 llvm::StringSaver stringSaver(options.first);
2730 StringRef opts = cmdOptions;
2731 // For a correct tokenization of the command line options `opts` must be
2732 // unquoted, otherwise the tokenization function returns a single string: the
2733 // unquoted `cmdOptions`, which is not the desired behavior.
2734 // Remove any quotes if they are at the beginning and end of the string:
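 // For example (editor's sketch): the quoted string "'-v --opt=3'" becomes
 // -v --opt=3 after unquoting and then tokenizes to {"-v", "--opt=3"}.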
2735 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2736 opts.consume_front("\""), opts.consume_back("\"");
2737 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2738 opts.consume_front("'"), opts.consume_back("'");
2739#ifdef _WIN32
2740 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2741 /*MarkEOLs=*/false);
2742#else
2743 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2744 /*MarkEOLs=*/false);
2745#endif // _WIN32
2746 return options;
2747}
2748
2749std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2750TargetOptions::tokenizeCmdOptions() const {
2751 return tokenizeCmdOptions(cmdOptions);
2752}
2753
2754std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2755TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2756 size_t startPos = cmdOptions.find(startsWith);
2757 if (startPos == std::string::npos)
2758 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2759
2760 auto tokenized =
2761 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2762 cmdOptions.resize(startPos);
2763 return tokenized;
2764}
2765
2766MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2767
2768#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2769#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2770
2771#define GET_ATTRDEF_CLASSES
2772#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2773
2774#define GET_OP_CLASSES
2775#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2776
2777#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:104
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
DialectInlinerInterface(Dialect *dialect)
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
bool isParent() const
Returns true if branching from the parent op.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
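A minimal pattern sketch showing how these rewriter hooks combine; MyOp is a placeholder single-result op whose operand can stand in for its result:

struct FoldPassThrough : OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // Replace the result with the operand; the driver then erases `op`.
    rewriter.replaceOp(op, op.getOperand());
    return success();
  }
};

void populateMyPatterns(RewritePatternSet &patterns) {
  // Register the pattern with the default benefit of 1.
  patterns.add<FoldPassThrough>(patterns.getContext());
}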
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isIndex() const
Definition Types.cpp:54
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
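These predicates are typically combined into element-type checks like the following sketch; the accepted set here is illustrative, not the dialect's actual rule:

static bool isAllowedElementType(Type type) {
  // Accept f16/f32 floats and 8/32-bit integer variants (example policy).
  return type.isF16() || type.isF32() || type.isInteger(32) ||
         type.isSignedInteger(8) || type.isUnsignedInteger(8);
}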
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
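A sketch of materializing a constant index value and inspecting it as a Value, in the style of the size-one gpu.launch simplification mentioned below:

static Value makeIndexOne(OpBuilder &builder, Location loc) {
  // Build `arith.constant 1 : index` at the current insertion point.
  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
  assert(one.getType().isIndex() && "constant index has index type");
  return one;
}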
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B; this returns which operand ("AOp", "BOp", or "COp") the matrix represents.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
unsigned getNumDims() const
Get number of dims.
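A sketch constructing and inspecting a matrix type with this API; the 16x16xf16 "AOp" configuration is an illustrative choice:

static gpu::MMAMatrixType buildAOperandType(MLIRContext *ctx) {
  Type f16 = Float16Type::get(ctx);
  // A 16x16 fragment representing the A operand of C += A*B.
  gpu::MMAMatrixType matType = gpu::MMAMatrixType::get({16, 16}, f16, "AOp");
  assert(matType.getNumDims() == 2 && matType.getElementType().isF16());
  return matType;
}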
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options, the compilation target, and the callbacks used during compilation.
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substring of the command line options that starts with startsWith and runs to the end of the options, removing that suffix from the stored options.
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
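A sketch of populating TargetOptions via the constructor above; the toolkit path and flags are illustrative values, not defaults:

static gpu::TargetOptions makeOptions() {
  // Compile to the default fatbin target with a custom toolkit and flags.
  return gpu::TargetOptions(
      /*toolkitPath=*/"/usr/local/cuda",
      /*librariesToLink=*/{},
      /*cmdOptions=*/"-O3",
      /*elfSection=*/"",
      /*compilationTarget=*/gpu::CompilationTarget::Fatbin);
}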
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
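A sketch of how a function-like op's parser typically pairs these helpers, assuming they live in the function_interface_impl namespace (error handling abbreviated):

static ParseResult parseSignature(OpAsmParser &parser,
                                  OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();
  // addArgAndResultAttrs would then attach the parsed argument/result
  // dictionaries to `result` under the op's arg/res attribute names.
  (void)result;
  return success();
}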
void addAsyncDependency(Operation *op, Value token)
Adds the given token to the list of async dependencies of op.
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator points to the found element or to the position where it would be inserted.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:45
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark for a transformation that failed.
Definition Remarks.h:561
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the integer type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
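A sketch of the constant-one test these matchers enable, as used by the size-one gpu.launch simplification below:

static bool isConstantOne(Value value) {
  // Matches integer/index constants (or splats) equal to one.
  return matchPattern(value, m_One());
}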
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.
Definition GPUDialect.h:39