//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

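// A mapping id is either a grid dimension (e.g. #gpu.block<x>, #gpu.thread<y>)
// or a linear dimension (e.g. #gpu.thread<linear_dim_0>). Linear ids start at
// MappingId::LinearDim0, so getRelativeIndex() rebases them to zero.
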
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

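// An MMA matrix fragment type appears in the IR as, for example (handled by
// GPUDialect::parseType/printType below):
//   !gpu.mma_matrix<16x16xf16, "AOp">
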
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

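// The custom gpu type syntax handled below includes, for example:
//   !gpu.async.token
//   !gpu.mma_matrix<16x16xf16, "AOp">
//   !gpu.sparse.dntensor_handle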
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: print the refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

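// For reference (a sketch): launch ops are expected under a module carrying
// the container attribute, e.g.:
//   module attributes {gpu.container_module} {
//     gpu.module @kernels { gpu.func @kernel(...) kernel { ... } }
//   }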
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with its leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}
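
// For example, an async op with dependencies prints as (a sketch):
//   %token = gpu.wait async [%dep0, %dep1]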

// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
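
// For example, a workgroup attribution appears as (a sketch):
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)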

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
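
// The two accepted forms, with an op attribute or a reduction region
// (a sketch):
//   %r = gpu.all_reduce add %v uniform {} : (f32) -> (f32)
//   %r = gpu.all_reduce %v {
//   ^bb(%lhs : f32, %rhs : f32):
//     %s = arith.addf %lhs, %rhs : f32
//     "gpu.yield"(%s) : (f32) -> ()
//   } : (f32) -> (f32)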

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

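// For example (a sketch): %r = gpu.subgroup_reduce add %v : (f32) -> (f32)
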
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
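  // The 11 operand segments are: async dependencies (variadic), grid x/y/z,
  // block x/y/z, cluster x/y/z (each 0 or 1), and the optional
  // dynamicSharedMemorySize (0 or 1).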
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
// (%iter-x, %iter-y, %iter-z) in
// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
// (%region_arg, %region_arg, %region_arg) in
// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
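/// For example (a sketch):
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
///              threads(%tx, %ty, %tz) in (%bw = %3, %bh = %4, %bd = %5) {
///     ...
///     gpu.terminator
///   }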
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size segments assign the cluster size. In the region
  // argument list, these are the last 6 arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

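// For example, a launch_func appears as (a sketch):
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%a : f32, %b : memref<?xf32>)
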
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
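
// For example (a sketch):
//   %res, %valid = gpu.shuffle xor %val, %offset, %width : f32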

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Removes a gpu.barrier that immediately follows another gpu.barrier: the
/// threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

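// A gpu.func with attributions and the kernel marker looks like (a sketch):
//   gpu.func @kernel(%arg0 : f32)
//       workgroup(%w : memref<32xf32, #gpu.address_space<workgroup>>) kernel {
//     ...
//     gpu.return
//   }
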
/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
1551 
1552 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1553  StringAttr attrName) {
1554  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1555  if (!allAttrs || index >= allAttrs.size())
1556  return DictionaryAttr();
1557  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1558 }
1559 
1560 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1561  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1562 }
1563 
1564 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1565  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1566 }
1567 
1568 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1569  DictionaryAttr value, StringAttr attrName) {
1570  MLIRContext *ctx = op.getContext();
1571  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1572  SmallVector<Attribute> elements;
1573  if (allAttrs)
1574  elements.append(allAttrs.begin(), allAttrs.end());
1575  while (elements.size() <= index)
1576  elements.push_back(DictionaryAttr::get(ctx));
1577  if (!value)
1578  elements[index] = DictionaryAttr::get(ctx);
1579  else
1580  elements[index] = value;
1581  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1582  op->setAttr(attrName, newValue);
1583 }
1584 
1585 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1586  DictionaryAttr value) {
1587  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1588 }
1589 
1590 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1591  DictionaryAttr value) {
1592  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1593 }
1594 
1595 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1596  StringAttr name, StringAttr attrsName) {
1597  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1598  if (!dict)
1599  return Attribute();
1600  return dict.get(name);
1601 }
1602 
1603 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1604  StringAttr name) {
1605  assert(index < getNumWorkgroupAttributions() &&
1606  "index must map to a workgroup attribution");
1607  return getAttributionAttr(*this, index, name,
1608  getWorkgroupAttribAttrsAttrName());
1609 }
1610 
1611 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1612  StringAttr name) {
1613  assert(index < getNumPrivateAttributions() &&
1614  "index must map to a private attribution");
1615  return getAttributionAttr(*this, index, name,
1616  getPrivateAttribAttrsAttrName());
1617 }
1618 
1619 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1620  Attribute value, StringAttr attrsName) {
1621  MLIRContext *ctx = op.getContext();
1622  SmallVector<NamedAttribute> elems;
1623  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1624  if (oldDict)
1625  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1626 
1627  bool found = false;
1628  bool mustSort = true;
1629  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1630  if (elems[i].getName() == name) {
1631  found = true;
1632  if (!value) {
1633  std::swap(elems[i], elems[elems.size() - 1]);
1634  elems.pop_back();
1635  } else {
1636  mustSort = false;
1637  elems[i] = NamedAttribute(elems[i].getName(), value);
1638  }
1639  break;
1640  }
1641  }
1642  if (!found) {
1643  if (!value)
1644  return;
1645  elems.emplace_back(name, value);
1646  }
1647  if (mustSort) {
1648  DictionaryAttr::sortInPlace(elems);
1649  }
1650  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1651  setAttributionAttrs(op, index, newDict, attrsName);
1652 }
1653 
1654 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1655  Attribute value) {
1656  assert(index < getNumWorkgroupAttributions() &&
1657  "index must map to a workgroup attribution");
1658  setAttributionAttr(*this, index, name, value,
1659  getWorkgroupAttribAttrsAttrName());
1660 }
1661 
1662 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1663  Attribute value) {
1664  assert(index < getNumPrivateAttributions() &&
1665  "index must map to a private attribution");
1666  setAttributionAttr(*this, index, name, value,
1667  getPrivateAttribAttrsAttrName());
1668 }
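
// Illustrative use of the setters above (the attribute name is a made-up
// example):
//   funcOp.setWorkgroupAttributionAttr(
//       /*index=*/0, builder.getStringAttr("my.alignment"),
//       builder.getI64IntegerAttr(16));
// Passing a null Attribute as `value` removes `name` from the dictionary.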
1669 
1670 LogicalResult GPUFuncOp::verifyType() {
1671  if (isKernel() && getFunctionType().getNumResults() != 0)
1672  return emitOpError() << "expected void return type for kernel function";
1673 
1674  return success();
1675 }
1676 
1677 /// Verifies the body of the function.
1678 LogicalResult GPUFuncOp::verifyBody() {
1679  if (empty())
1680  return emitOpError() << "expected body with at least one block";
1681  unsigned numFuncArguments = getNumArguments();
1682  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1683  unsigned numBlockArguments = front().getNumArguments();
1684  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1685  return emitOpError() << "expected at least "
1686  << numFuncArguments + numWorkgroupAttributions
1687  << " arguments to body region";
1688 
1689  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1690  for (unsigned i = 0; i < numFuncArguments; ++i) {
1691  Type blockArgType = front().getArgument(i).getType();
1692  if (funcArgTypes[i] != blockArgType)
1693  return emitOpError() << "expected body region argument #" << i
1694  << " to be of type " << funcArgTypes[i] << ", got "
1695  << blockArgType;
1696  }
1697 
1698  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1699  GPUDialect::getWorkgroupAddressSpace())) ||
1700  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1701  GPUDialect::getPrivateAddressSpace())))
1702  return failure();
1703 
1704  return success();
1705 }
1706 
1707 //===----------------------------------------------------------------------===//
1708 // ReturnOp
1709 //===----------------------------------------------------------------------===//
1710 
1711 LogicalResult gpu::ReturnOp::verify() {
1712  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1713 
1714  FunctionType funType = function.getFunctionType();
1715 
1716  if (funType.getNumResults() != getOperands().size())
1717  return emitOpError()
1718  .append("expected ", funType.getNumResults(), " result operands")
1719  .attachNote(function.getLoc())
1720  .append("return type declared here");
1721 
1722  for (const auto &pair : llvm::enumerate(
1723  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1724  auto [type, operand] = pair.value();
1725  if (type != operand.getType())
1726  return emitOpError() << "unexpected type `" << operand.getType()
1727  << "' for operand #" << pair.index();
1728  }
1729  return success();
1730 }
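
// For example (illustrative), in a gpu.func with result type (f32) the
// terminator must be `gpu.return %v : f32`; any mismatch in operand count or
// type is rejected by the checks above.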
1731 
1732 //===----------------------------------------------------------------------===//
1733 // GPUModuleOp
1734 //===----------------------------------------------------------------------===//
1735 
1736 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1737  StringRef name, ArrayAttr targets,
1738  Attribute offloadingHandler) {
1739  ensureTerminator(*result.addRegion(), builder, result.location);
1740 
1741  Properties &props = result.getOrAddProperties<Properties>();
1742  if (targets)
1743  props.targets = targets;
1744  props.setSymName(builder.getStringAttr(name));
1745  props.offloadingHandler = offloadingHandler;
1746 }
1747 
1748 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1749  StringRef name, ArrayRef<Attribute> targets,
1750  Attribute offloadingHandler) {
1751  build(builder, result, name,
1752  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1753  offloadingHandler);
1754 }
1755 
1756 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1757  StringAttr nameAttr;
1758  ArrayAttr targetsAttr;
1759 
1760  if (parser.parseSymbolName(nameAttr))
1761  return failure();
1762 
1763  Properties &props = result.getOrAddProperties<Properties>();
1764  props.setSymName(nameAttr);
1765 
1766  // Parse the optional offloadingHandler
1767  if (succeeded(parser.parseOptionalLess())) {
1768  if (parser.parseAttribute(props.offloadingHandler))
1769  return failure();
1770  if (parser.parseGreater())
1771  return failure();
1772  }
1773 
1774  // Parse the optional array of target attributes.
1775  OptionalParseResult targetsAttrResult =
1776  parser.parseOptionalAttribute(targetsAttr, Type{});
1777  if (targetsAttrResult.has_value()) {
1778  if (failed(*targetsAttrResult)) {
1779  return failure();
1780  }
1781  props.targets = targetsAttr;
1782  }
1783 
1784  // If module attributes are present, parse them.
1785  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1786  return failure();
1787 
1788  // Parse the module body.
1789  auto *body = result.addRegion();
1790  if (parser.parseRegion(*body, {}))
1791  return failure();
1792 
1793  // Ensure that this module has a valid terminator.
1794  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1795  return success();
1796 }
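
// An illustrative module accepted by the parser above (the handler and target
// attributes are made-up placeholders):
//
//   gpu.module @kernels <#gpu.select_object<0>> [#nvvm.target] {
//     ...
//   }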
1797 
1798 void GPUModuleOp::print(OpAsmPrinter &p) {
1799  p << ' ';
1800  p.printSymbolName(getName());
1801 
1802  if (Attribute attr = getOffloadingHandlerAttr()) {
1803  p << " <";
1804  p.printAttribute(attr);
1805  p << ">";
1806  }
1807 
1808  if (Attribute attr = getTargetsAttr()) {
1809  p << ' ';
1810  p.printAttribute(attr);
1811  p << ' ';
1812  }
1813 
1814  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1815  {mlir::SymbolTable::getSymbolAttrName(),
1816  getTargetsAttrName(),
1817  getOffloadingHandlerAttrName()});
1818  p << ' ';
1819  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1820  /*printBlockTerminators=*/false);
1821 }
1822 
1823 bool GPUModuleOp::hasTarget(Attribute target) {
1824  if (ArrayAttr targets = getTargetsAttr())
1825  return llvm::count(targets.getValue(), target);
1826  return false;
1827 }
1828 
1829 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1830  ArrayAttr &targetsAttr = getProperties().targets;
1831  SmallVector<Attribute> targetsVector(targets);
1832  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1833 }
1834 
1835 //===----------------------------------------------------------------------===//
1836 // GPUBinaryOp
1837 //===----------------------------------------------------------------------===//
1838 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1839  Attribute offloadingHandler, ArrayAttr objects) {
1840  auto &properties = result.getOrAddProperties<Properties>();
1841  result.attributes.push_back(builder.getNamedAttr(
1842  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1843  properties.objects = objects;
1844  if (offloadingHandler)
1845  properties.offloadingHandler = offloadingHandler;
1846  else
1847  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1848 }
1849 
1850 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1851  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1852  build(builder, result, name, offloadingHandler,
1853  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1854 }
1855 
1856 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1857  Attribute &offloadingHandler) {
1858  if (succeeded(parser.parseOptionalLess())) {
1859  if (parser.parseAttribute(offloadingHandler))
1860  return failure();
1861  if (parser.parseGreater())
1862  return failure();
1863  }
1864  if (!offloadingHandler)
1865  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1866  return success();
1867 }
1868 
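// Printing is symmetric with parseOffloadingHandler above: a non-default
// handler is wrapped in angle brackets, and the default SelectObjectAttr
// built during parsing is elided entirely.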
1869 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1870  Attribute offloadingHandler) {
1871  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1872  printer << '<' << offloadingHandler << '>';
1873 }
1874 
1875 //===----------------------------------------------------------------------===//
1876 // GPUMemcpyOp
1877 //===----------------------------------------------------------------------===//
1878 
1879 LogicalResult MemcpyOp::verify() {
1880  auto srcType = getSrc().getType();
1881  auto dstType = getDst().getType();
1882 
1883  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1884  return emitOpError("arguments have incompatible element type");
1885 
1886  if (failed(verifyCompatibleShape(srcType, dstType)))
1887  return emitOpError("arguments have incompatible shape");
1888 
1889  return success();
1890 }
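
// For example (illustrative), the verifier above accepts:
//   gpu.memcpy %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
// and rejects operands whose element types or shapes are incompatible.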
1891 
1892 namespace {
1893 
1894 /// Erases a common case of copy ops where the destination value is used only
1895 /// by the copy op itself and by alloc and dealloc ops.
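/// Illustrative IR (made-up example); only the copy itself is erased here:
///   %dst = gpu.alloc() : memref<16xf32>
///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
///   gpu.dealloc %dst : memref<16xf32>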
1896 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1897  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1898 
1899  LogicalResult matchAndRewrite(MemcpyOp op,
1900  PatternRewriter &rewriter) const override {
1901  Value dest = op.getDst();
1902  Operation *destDefOp = dest.getDefiningOp();
1903  // `dest` must be defined by an op having Allocate memory effect in order to
1904  // perform the folding.
1905  if (!destDefOp ||
1906  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1907  return failure();
1908  // We can erase `op` iff `dest` has no other use apart from its
1909  // use by `op` and dealloc ops.
1910  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1911  return user != op &&
1912  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1913  }))
1914  return failure();
1915  // We can perform the folding if and only if op has a single async
1916  // dependency and produces an async token as result, or if it does not have
1917  // any async dependency and does not produce any async token result.
1918  if (op.getAsyncDependencies().size() > 1 ||
1919  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1920  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1921  return failure();
1922  rewriter.replaceOp(op, op.getAsyncDependencies());
1923  return success();
1924  }
1925 };
1926 
1927 } // end anonymous namespace
1928 
1929 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1930  MLIRContext *context) {
1931  results.add<EraseTrivialCopyOp>(context);
1932 }
1933 
1934 //===----------------------------------------------------------------------===//
1935 // GPU_SubgroupMmaLoadMatrixOp
1936 //===----------------------------------------------------------------------===//
1937 
1938 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1939  auto srcType = getSrcMemref().getType();
1940  auto resType = getRes().getType();
1941  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1942  auto operand = resMatrixType.getOperand();
1943  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1944 
1945  if (!isLastMemrefDimUnitStride(srcMemrefType))
1946  return emitError(
1947  "expected source memref most minor dim must have unit stride");
1948 
1949  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1950  return emitError("only AOp, BOp and COp can be loaded");
1951 
1952  return success();
1953 }
1954 
1955 //===----------------------------------------------------------------------===//
1956 // GPU_SubgroupMmaStoreMatrixOp
1957 //===----------------------------------------------------------------------===//
1958 
1959 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1960  auto srcType = getSrc().getType();
1961  auto dstType = getDstMemref().getType();
1962  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1963  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1964 
1965  if (!isLastMemrefDimUnitStride(dstMemrefType))
1966  return emitError(
1967  "expected destination memref most minor dim must have unit stride");
1968 
1969  if (srcMatrixType.getOperand() != "COp")
1970  return emitError(
1971  "expected the operand matrix being stored to have 'COp' operand type");
1972 
1973  return success();
1974 }
1975 
1976 //===----------------------------------------------------------------------===//
1977 // GPU_SubgroupMmaComputeOp
1978 //===----------------------------------------------------------------------===//
1979 
1980 LogicalResult SubgroupMmaComputeOp::verify() {
1981  enum OperandMap { A, B, C };
1982  SmallVector<MMAMatrixType, 3> opTypes;
1983  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1984  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1985  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1986 
1987  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1988  opTypes[C].getOperand() != "COp")
1989  return emitError("operands must be in the order AOp, BOp, COp");
1990 
1991  ArrayRef<int64_t> aShape, bShape, cShape;
1992  aShape = opTypes[A].getShape();
1993  bShape = opTypes[B].getShape();
1994  cShape = opTypes[C].getShape();
1995 
1996  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1997  bShape[1] != cShape[1])
1998  return emitError("operand shapes do not satisfy matmul constraints");
1999 
2000  return success();
2001 }
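
// In matmul terms the check above enforces A : MxK, B : KxN, C : MxN; e.g.
// (illustrative) a 16x16 "AOp" times a 16x16 "BOp" accumulating into a 16x16
// "COp" satisfies the constraints.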
2002 
2003 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2004  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2005  return memref::foldMemRefCast(*this);
2006 }
2007 
2008 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2009  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2010  return memref::foldMemRefCast(*this);
2011 }
2012 
2013 //===----------------------------------------------------------------------===//
2014 // GPU_WaitOp
2015 //===----------------------------------------------------------------------===//
2016 
2017 namespace {
2018 
2019 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2020 /// %t = gpu.wait async [] // No async dependencies.
2021 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2022 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2023 public:
2024  using OpRewritePattern::OpRewritePattern;
2025 
2026  LogicalResult matchAndRewrite(WaitOp op,
2027  PatternRewriter &rewriter) const final {
2028  auto predicate = [](Value value) {
2029  auto waitOp = value.getDefiningOp<WaitOp>();
2030  return waitOp && waitOp->getNumOperands() == 0;
2031  };
2032  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2033  return failure();
2034  SmallVector<Value> validOperands;
2035  for (Value operand : op->getOperands()) {
2036  if (predicate(operand))
2037  continue;
2038  validOperands.push_back(operand);
2039  }
2040  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2041  return success();
2042  }
2043 };
2044 
2045 /// Simplify trivial gpu.wait ops for the following patterns.
2046 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2047 /// dependencies).
2048 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2049 /// %t0.
2050 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2051 /// dependencies nor return any token.
2052 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2053 public:
2054  using OpRewritePattern::OpRewritePattern;
2055 
2056  LogicalResult matchAndRewrite(WaitOp op,
2057  PatternRewriter &rewriter) const final {
2058  // Erase gpu.wait ops that neither have any async dependencies nor return
2059  // any async token.
2060  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2061  rewriter.eraseOp(op);
2062  return success();
2063  }
2064  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2065  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2066  op.getAsyncToken()) {
2067  rewriter.replaceOp(op, op.getAsyncDependencies());
2068  return success();
2069  }
2070  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2071  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2072  rewriter.eraseOp(op);
2073  return success();
2074  }
2075  return failure();
2076  }
2077 };
2078 
2079 } // end anonymous namespace
2080 
2081 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2082  MLIRContext *context) {
2083  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // GPU_AllocOp
2088 //===----------------------------------------------------------------------===//
2089 
2090 LogicalResult AllocOp::verify() {
2091  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2092 
2093  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2094  memRefType.getNumDynamicDims())
2095  return emitOpError("dimension operand count does not equal memref "
2096  "dynamic dimension count");
2097 
2098  unsigned numSymbols = 0;
2099  if (!memRefType.getLayout().isIdentity())
2100  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2101  if (getSymbolOperands().size() != numSymbols) {
2102  return emitOpError(
2103  "symbol operand count does not equal memref symbol count");
2104  }
2105 
2106  return success();
2107 }
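
// For example (illustrative), a memref with two dynamic dimensions requires
// exactly two size operands:
//   %m = gpu.alloc(%d0, %d1) : memref<?x?xf32>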
2108 
2109 namespace {
2110 
2111 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2112 /// `memref::AllocOp`.
2113 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2114  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2115 
2116  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2117  PatternRewriter &rewriter) const override {
2118  std::optional<int64_t> index = dimOp.getConstantIndex();
2119  if (!index)
2120  return failure();
2121 
2122  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2123  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2124  return failure();
2125 
2126  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2127  if (!alloc)
2128  return failure();
2129 
2130  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2131  memrefType.getDynamicDimIndex(index.value()));
2132  rewriter.replaceOp(dimOp, substituteOp);
2133  return success();
2134  }
2135 };
2136 
2137 } // namespace
2138 
2139 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2140  MLIRContext *context) {
2141  results.add<SimplifyDimOfAllocOp>(context);
2142 }
2143 
2144 //===----------------------------------------------------------------------===//
2145 // GPU object attribute
2146 //===----------------------------------------------------------------------===//
2147 
2148 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2149  Attribute target, CompilationTarget format,
2150  StringAttr object, DictionaryAttr properties) {
2151  if (!target)
2152  return emitError() << "the target attribute cannot be null";
2153  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2154  return success();
2155  return emitError() << "the target attribute must implement or promise the "
2156  "`gpu::TargetAttrInterface`";
2157 }
2158 
2159 namespace {
2160 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2161  StringAttr &object) {
2162  std::optional<CompilationTarget> formatResult;
2163  StringRef enumKeyword;
2164  auto loc = odsParser.getCurrentLocation();
2165  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2166  formatResult = CompilationTarget::Fatbin;
2167  if (!formatResult &&
2168  (formatResult =
2169  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2170  odsParser.parseEqual())
2171  return odsParser.emitError(loc, "expected an equal sign");
2172  if (!formatResult)
2173  return odsParser.emitError(loc, "expected keyword for GPU object format");
2174  FailureOr<StringAttr> objectResult =
2175  FieldParser<StringAttr>::parse(odsParser);
2176  if (failed(objectResult))
2177  return odsParser.emitError(odsParser.getCurrentLocation(),
2178  "failed to parse GPU_ObjectAttr parameter "
2179  "'object' which is to be a `StringAttr`");
2180  format = *formatResult;
2181  object = *objectResult;
2182  return success();
2183 }
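
// Illustrative round trips for the custom object syntax (the target attribute
// is a placeholder):
//   #gpu.object<#nvvm.target, "...">             // format defaults to fatbin
//   #gpu.object<#nvvm.target, assembly = "...">  // explicit format keyword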
2184 
2185 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2186  StringAttr object) {
2187  if (format != CompilationTarget::Fatbin)
2188  odsParser << stringifyEnum(format) << " = ";
2189  odsParser << object;
2190 }
2191 } // namespace
2192 
2193 //===----------------------------------------------------------------------===//
2194 // GPU select object attribute
2195 //===----------------------------------------------------------------------===//
2196 
2197 LogicalResult
2198 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2199  Attribute target) {
2200  // Check `target`: it can be null, an integer attr, or a GPU Target attribute.
2201  if (target) {
2202  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2203  if (intAttr.getInt() < 0) {
2204  return emitError() << "the object index must be positive";
2205  }
2206  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2207  return emitError()
2208  << "the target attribute must be a GPU Target attribute";
2209  }
2210  }
2211  return success();
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // DynamicSharedMemoryOp
2216 //===----------------------------------------------------------------------===//
2217 
2218 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2219  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2220  return emitOpError() << "must be inside an op with symbol table";
2221 
2222  MemRefType memrefType = getResultMemref().getType();
2223  // Check address space
2224  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2225  return emitOpError() << "address space must be "
2226  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2227  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2228  }
2229  if (memrefType.hasStaticShape()) {
2230  return emitOpError() << "result memref type must be memref<?xi8, "
2231  "#gpu.address_space<workgroup>>";
2232  }
2233  return success();
2234 }
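
// A valid use is, for example (illustrative):
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>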
2235 
2236 //===----------------------------------------------------------------------===//
2237 // GPU target options
2238 //===----------------------------------------------------------------------===//
2239 
2240 TargetOptions::TargetOptions(
2241  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2242  StringRef cmdOptions, CompilationTarget compilationTarget,
2243  function_ref<SymbolTable *()> getSymbolTableCallback)
2244  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2245  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2246 
2247 TargetOptions::TargetOptions(
2248  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2249  StringRef cmdOptions, CompilationTarget compilationTarget,
2250  function_ref<SymbolTable *()> getSymbolTableCallback)
2251  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2252  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2253  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2254 
2255 TypeID TargetOptions::getTypeID() const { return typeID; }
2256 
2257 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2258 
2259 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2260 
2261 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2262 
2263 SymbolTable *TargetOptions::getSymbolTable() const {
2264  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2265 }
2266 
2267 CompilationTarget TargetOptions::getCompilationTarget() const {
2268  return compilationTarget;
2269 }
2270 
2271 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2272  return CompilationTarget::Fatbin;
2273 }
2274 
2275 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2276 TargetOptions::tokenizeCmdOptions() const {
2277  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2278  llvm::StringSaver stringSaver(options.first);
2279  StringRef opts = cmdOptions;
2280  // For a correct tokenization of the command line options, `opts` must be
2281  // unquoted; otherwise the tokenization function returns a single string (the
2282  // unquoted `cmdOptions`), which is not the desired behavior.
2283  // Remove any quotes if they are at the beginning and end of the string:
2284  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2285  opts.consume_front("\""), opts.consume_back("\"");
2286  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2287  opts.consume_front("'"), opts.consume_back("'");
2288 #ifdef _WIN32
2289  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2290  /*MarkEOLs=*/false);
2291 #else
2292  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2293  /*MarkEOLs=*/false);
2294 #endif // _WIN32
2295  return options;
2296 }
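
// Illustrative behavior (made-up input): with cmdOptions set to the quoted
// string "\"-O3 -v\"", the surrounding quotes are stripped first and the
// GNU/Windows tokenizer then yields the two tokens "-O3" and "-v".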
2297 
2298 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2299 
2300 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2301 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2302 
2303 #define GET_ATTRDEF_CLASSES
2304 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2305 
2306 #define GET_OP_CLASSES
2307 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2308 
2309 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"