//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
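
// For illustration (not in the original source): per the verifier above, a
// well-formed fragment type in the IR looks like
//   !gpu.mma_matrix<16x16xf16, "AOp">
// while a rank-3 shape, an unknown operand string, or an f64 element type
// would all be rejected.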

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here. Note that it should correspond to the
// parser above.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
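
// For illustration (not in the original source): this verifier accepts a
// known-launch-size attribute of the shape, e.g.,
//   gpu.known_block_size = array<i32: 128, 1, 1>
// and rejects anything that is not a dense i32 array of exactly three
// elements.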

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}
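
// For illustration (not in the original source): with this format an async op
// prints as, e.g.,
//   %token = gpu.wait async [%dep0, %dep1]
// where the result is named because the op carries the `async` keyword.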

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
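
// For illustration (not in the original source): attributions as parsed,
// printed, and verified above look like
//   workgroup(%sum : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%tmp : memref<4xf32, #gpu.address_space<private>>)
// where each buffer's memory space must match the keyword it appears under.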

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
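
// For illustration (not in the original source): the two accepted forms are,
// e.g., the attribute form
//   %sum = gpu.all_reduce add %val uniform {} : (f32) -> f32
// and the region form, whose block takes two arguments of the result type and
// terminates with a single gpu.yield of that same type.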

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
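
// For illustration (not in the original source): a clustered reduction obeying
// the power-of-two rules above might look like
//   %sum = gpu.subgroup_reduce add %val cluster(size = 4, stride = 2)
//       : (f32) -> f32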

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
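
// For illustration (not in the original source): the 11 operand segments
// filled above are, in order, [asyncDependencies, gridSizeX, gridSizeY,
// gridSizeZ, blockSizeX, blockSizeY, blockSizeZ, clusterSizeX, clusterSizeY,
// clusterSizeZ, dynamicSharedMemorySize]; each holds zero or one operand
// except the variadic async-dependency list.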

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
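///
/// For illustration (not in the original source), a minimal launch:
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
///              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c1, %sz = %c1) {
///     gpu.terminator
///   }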
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size. In the region argument
  // list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, plus a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}
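
// For illustration (not in the original source): given
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, ...) ... { ... }
// the pattern above replaces every use of %bx inside the body with a constant
// 0, since a grid of size one admits only that block id.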

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}
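
// For illustration (not in the original source): a launch_func using this
// operand syntax might look like
//   gpu.launch_func @kernels::@vecadd blocks in (%gx, %gy, %gz)
//       threads in (%bx, %by, %bz) args(%a : memref<?xf32>, %n : index)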

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
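
// For illustration (not in the original source): this convenience builder lets
// clients create, e.g.,
//   %res, %valid = gpu.shuffle xor %val, %offset, %width : f32
// without materializing the two i32 constants themselves.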

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
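
// For illustration (not in the original source): the pattern above rewrites
//   gpu.barrier
//   gpu.barrier
// into a single gpu.barrier, since the second one adds no synchronization.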

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
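///
/// For illustration (not in the original source), e.g.:
///   gpu.func @kernel(%arg0 : memref<?xf32>)
///       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
///       kernel {
///     gpu.return
///   }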
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
1526 
1527 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1528  ArrayRef<BlockArgument> values,
1529  ArrayAttr attributes) {
1530  if (values.empty())
1531  return;
1532 
1533  p << ' ' << keyword << '(';
1534  llvm::interleaveComma(
1535  llvm::enumerate(values), p, [&p, attributes](auto pair) {
1536  BlockArgument v = pair.value();
1537  p << v << " : " << v.getType();
1538 
1539  size_t attributionIndex = pair.index();
1540  DictionaryAttr attrs;
1541  if (attributes && attributionIndex < attributes.size())
1542  attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1543  if (attrs)
1544  p.printOptionalAttrDict(attrs.getValue());
1545  });
1546  p << ')';
1547 }
1548 
1549 void GPUFuncOp::print(OpAsmPrinter &p) {
1550  p << ' ';
1551  p.printSymbolName(getName());
1552 
1553  FunctionType type = getFunctionType();
1554  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1555  /*isVariadic=*/false,
1556  type.getResults());
1557 
1558  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1559  getWorkgroupAttribAttrs().value_or(nullptr));
1560  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1561  getPrivateAttribAttrs().value_or(nullptr));
1562  if (isKernel())
1563  p << ' ' << getKernelKeyword();
1564 
1565  function_interface_impl::printFunctionAttributes(
1566  p, *this,
1567  {getNumWorkgroupAttributionsAttrName(),
1568  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1569  getArgAttrsAttrName(), getResAttrsAttrName(),
1570  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1571  p << ' ';
1572  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1573 }
1574 
1575 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1576  StringAttr attrName) {
1577  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1578  if (!allAttrs || index >= allAttrs.size())
1579  return DictionaryAttr();
1580  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1581 }
1582 
1583 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1584  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1585 }
1586 
1587 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1588  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1589 }
1590 
1591 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1592  DictionaryAttr value, StringAttr attrName) {
1593  MLIRContext *ctx = op.getContext();
1594  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1595  SmallVector<Attribute> elements;
1596  if (allAttrs)
1597  elements.append(allAttrs.begin(), allAttrs.end());
1598  while (elements.size() <= index)
1599  elements.push_back(DictionaryAttr::get(ctx));
1600  if (!value)
1601  elements[index] = DictionaryAttr::get(ctx);
1602  else
1603  elements[index] = value;
1604  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1605  op->setAttr(attrName, newValue);
1606 }
1607 
1608 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1609  DictionaryAttr value) {
1610  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1611 }
1612 
1613 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1614  DictionaryAttr value) {
1615  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1616 }
1617 
1618 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1619  StringAttr name, StringAttr attrsName) {
1620  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1621  if (!dict)
1622  return Attribute();
1623  return dict.get(name);
1624 }
1625 
1626 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1627  StringAttr name) {
1628  assert(index < getNumWorkgroupAttributions() &&
1629  "index must map to a workgroup attribution");
1630  return getAttributionAttr(*this, index, name,
1631  getWorkgroupAttribAttrsAttrName());
1632 }
1633 
1634 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1635  StringAttr name) {
1636  assert(index < getNumPrivateAttributions() &&
1637  "index must map to a private attribution");
1638  return getAttributionAttr(*this, index, name,
1639  getPrivateAttribAttrsAttrName());
1640 }
1641 
1642 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1643  Attribute value, StringAttr attrsName) {
1644  MLIRContext *ctx = op.getContext();
1645  SmallVector<NamedAttribute> elems;
1646  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1647  if (oldDict)
1648  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1649 
1650  bool found = false;
1651  bool mustSort = true;
1652  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1653  if (elems[i].getName() == name) {
1654  found = true;
1655  if (!value) {
1656  std::swap(elems[i], elems[elems.size() - 1]);
1657  elems.pop_back();
1658  } else {
1659  mustSort = false;
1660  elems[i] = NamedAttribute(elems[i].getName(), value);
1661  }
1662  break;
1663  }
1664  }
1665  if (!found) {
1666  if (!value)
1667  return;
1668  elems.emplace_back(name, value);
1669  }
1670  if (mustSort) {
1671  DictionaryAttr::sortInPlace(elems);
1672  }
1673  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1674  setAttributionAttrs(op, index, newDict, attrsName);
1675 }
1676 
1677 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1678  Attribute value) {
1679  assert(index < getNumWorkgroupAttributions() &&
1680  "index must map to a workgroup attribution");
1681  setAttributionAttr(*this, index, name, value,
1682  getWorkgroupAttribAttrsAttrName());
1683 }
1684 
1685 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1686  Attribute value) {
1687  assert(index < getNumPrivateAttributions() &&
1688  "index must map to a private attribution");
1689  setAttributionAttr(*this, index, name, value,
1690  getPrivateAttribAttrsAttrName());
1691 }
1692 
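// A minimal usage sketch for the attribution-attribute API above; `func`
// and `b` are hypothetical GPUFuncOp and Builder values, and "my.alignment"
// is a made-up attribute name:
//
//   // Attach an entry to the dictionary of workgroup attribution #0.
//   func.setWorkgroupAttributionAttr(
//       /*index=*/0, b.getStringAttr("my.alignment"), b.getI64IntegerAttr(16));
//   // Read it back; a null Attribute is returned when the entry is absent.
//   Attribute align =
//       func.getWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"));
//   // Passing a null value removes the entry again.
//   func.setWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"), {});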
1693 LogicalResult GPUFuncOp::verifyType() {
1694  if (isKernel() && getFunctionType().getNumResults() != 0)
1695  return emitOpError() << "expected void return type for kernel function";
1696 
1697  return success();
1698 }
1699 
1700 /// Verifies the body of the function.
1701 LogicalResult GPUFuncOp::verifyBody() {
1702  if (empty())
1703  return emitOpError() << "expected body with at least one block";
1704  unsigned numFuncArguments = getNumArguments();
1705  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1706  unsigned numBlockArguments = front().getNumArguments();
1707  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1708  return emitOpError() << "expected at least "
1709  << numFuncArguments + numWorkgroupAttributions
1710  << " arguments to body region";
1711 
1712  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1713  for (unsigned i = 0; i < numFuncArguments; ++i) {
1714  Type blockArgType = front().getArgument(i).getType();
1715  if (funcArgTypes[i] != blockArgType)
1716  return emitOpError() << "expected body region argument #" << i
1717  << " to be of type " << funcArgTypes[i] << ", got "
1718  << blockArgType;
1719  }
1720 
1721  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1722  GPUDialect::getWorkgroupAddressSpace())) ||
1723  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1724  GPUDialect::getPrivateAddressSpace())))
1725  return failure();
1726 
1727  return success();
1728 }
1729 
1730 //===----------------------------------------------------------------------===//
1731 // ReturnOp
1732 //===----------------------------------------------------------------------===//
1733 
1734 LogicalResult gpu::ReturnOp::verify() {
1735  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1736 
1737  FunctionType funType = function.getFunctionType();
1738 
1739  if (funType.getNumResults() != getOperands().size())
1740  return emitOpError()
1741  .append("expected ", funType.getNumResults(), " result operands")
1742  .attachNote(function.getLoc())
1743  .append("return type declared here");
1744 
1745  for (const auto &pair : llvm::enumerate(
1746  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1747  auto [type, operand] = pair.value();
1748  if (type != operand.getType())
1749  return emitOpError() << "unexpected type `" << operand.getType()
1750  << "' for operand #" << pair.index();
1751  }
1752  return success();
1753 }
1754 
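// For illustration, a non-kernel gpu.func whose gpu.return operands match
// the declared result types (a sketch):
//
//   gpu.func @scale(%x: f32) -> f32 {
//     %y = arith.mulf %x, %x : f32
//     gpu.return %y : f32
//   }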
1755 //===----------------------------------------------------------------------===//
1756 // GPUModuleOp
1757 //===----------------------------------------------------------------------===//
1758 
1759 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1760  StringRef name, ArrayAttr targets,
1761  Attribute offloadingHandler) {
1762  result.addRegion()->emplaceBlock();
1763  Properties &props = result.getOrAddProperties<Properties>();
1764  if (targets)
1765  props.targets = targets;
1766  props.setSymName(builder.getStringAttr(name));
1767  props.offloadingHandler = offloadingHandler;
1768 }
1769 
1770 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1771  StringRef name, ArrayRef<Attribute> targets,
1772  Attribute offloadingHandler) {
1773  build(builder, result, name,
1774  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1775  offloadingHandler);
1776 }
1777 
1778 bool GPUModuleOp::hasTarget(Attribute target) {
1779  if (ArrayAttr targets = getTargetsAttr())
1780  return llvm::count(targets.getValue(), target);
1781  return false;
1782 }
1783 
1784 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1785  ArrayAttr &targetsAttr = getProperties().targets;
1786  SmallVector<Attribute> targetsVector(targets);
1787  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1788 }
1789 
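// For illustration, a gpu.module carrying target attributes (a sketch; the
// chip names are placeholders):
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">,
//                        #rocdl.target<chip = "gfx90a">] {
//     gpu.func @kernel() kernel { gpu.return }
//   }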
1790 //===----------------------------------------------------------------------===//
1791 // GPUBinaryOp
1792 //===----------------------------------------------------------------------===//
1793 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1794  Attribute offloadingHandler, ArrayAttr objects) {
1795  auto &properties = result.getOrAddProperties<Properties>();
1796  result.attributes.push_back(builder.getNamedAttr(
1797  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1798  properties.objects = objects;
1799  if (offloadingHandler)
1800  properties.offloadingHandler = offloadingHandler;
1801  else
1802  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1803 }
1804 
1805 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1806  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1807  build(builder, result, name, offloadingHandler,
1808  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1809 }
1810 
1811 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1812  Attribute &offloadingHandler) {
1813  if (succeeded(parser.parseOptionalLess())) {
1814  if (parser.parseAttribute(offloadingHandler))
1815  return failure();
1816  if (parser.parseGreater())
1817  return failure();
1818  }
1819  if (!offloadingHandler)
1820  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1821  return success();
1822 }
1823 
1824 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1825  Attribute offloadingHandler) {
1826  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1827  printer << '<' << offloadingHandler << '>';
1828 }
1829 
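// For illustration, gpu.binary in textual form; the offloading handler in
// angle brackets is optional and, when omitted, defaults to a
// SelectObjectAttr with a null target (object contents are placeholders):
//
//   gpu.binary @program [#gpu.object<#nvvm.target, "binary data">]
//   gpu.binary @program2 <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#rocdl.target, "...">]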
1830 //===----------------------------------------------------------------------===//
1831 // GPUMemcpyOp
1832 //===----------------------------------------------------------------------===//
1833 
1834 LogicalResult MemcpyOp::verify() {
1835  auto srcType = getSrc().getType();
1836  auto dstType = getDst().getType();
1837 
1838  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1839  return emitOpError("arguments have incompatible element type");
1840 
1841  if (failed(verifyCompatibleShape(srcType, dstType)))
1842  return emitOpError("arguments have incompatible shape");
1843 
1844  return success();
1845 }
1846 
1847 namespace {
1848 
1849 /// Erases a common case of copy ops where a destination value is used only by
1850 /// the copy op, alloc and dealloc ops.
1851 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1852  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1853 
1854  LogicalResult matchAndRewrite(MemcpyOp op,
1855  PatternRewriter &rewriter) const override {
1856  Value dest = op.getDst();
1857  Operation *destDefOp = dest.getDefiningOp();
1858  // `dest` must be defined by an op having Allocate memory effect in order to
1859  // perform the folding.
1860  if (!destDefOp ||
1861  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1862  return failure();
1863  // We can erase `op` iff `dest` has no other use apart from its
1864  // use by `op` and dealloc ops.
1865  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1866  return user != op &&
1867  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1868  }))
1869  return failure();
1870  // We can perform the folding if and only if op has a single async
1871  // dependency and produces an async token as result, or if it does not have
1872  // any async dependency and does not produce any async token result.
1873  if (op.getAsyncDependencies().size() > 1 ||
1874  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1875  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1876  return failure();
1877  rewriter.replaceOp(op, op.getAsyncDependencies());
1878  return success();
1879  }
1880 };
1881 
1882 } // end anonymous namespace
1883 
1884 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1885  MLIRContext *context) {
1886  results.add<EraseTrivialCopyOp>(context);
1887 }
1888 
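// For illustration, the dead-copy sequence EraseTrivialCopyOp targets
// (a sketch): %dst is defined by an allocating op and has no users other
// than the copy and a deallocating op, so the copy can be removed.
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>
//
// Only the gpu.memcpy itself is erased here; the now-trivially-dead alloc
// and dealloc are left for other cleanups to remove.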
1889 //===----------------------------------------------------------------------===//
1890 // GPU_SubgroupMmaLoadMatrixOp
1891 //===----------------------------------------------------------------------===//
1892 
1893 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1894  auto srcType = getSrcMemref().getType();
1895  auto resType = getRes().getType();
1896  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1897  auto operand = resMatrixType.getOperand();
1898  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1899 
1900  if (!isLastMemrefDimUnitStride(srcMemrefType))
1901  return emitError(
1902  "expected source memref most minor dim must have unit stride");
1903 
1904  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1905  return emitError("only AOp, BOp and COp can be loaded");
1906 
1907  return success();
1908 }
1909 
1910 //===----------------------------------------------------------------------===//
1911 // GPU_SubgroupMmaStoreMatrixOp
1912 //===----------------------------------------------------------------------===//
1913 
1914 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1915  auto srcType = getSrc().getType();
1916  auto dstType = getDstMemref().getType();
1917  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1918  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1919 
1920  if (!isLastMemrefDimUnitStride(dstMemrefType))
1921  return emitError(
1922  "expected destination memref most minor dim must have unit stride");
1923 
1924  if (srcMatrixType.getOperand() != "COp")
1925  return emitError(
1926  "expected the operand matrix being stored to have 'COp' operand type");
1927 
1928  return success();
1929 }
1930 
1931 //===----------------------------------------------------------------------===//
1932 // GPU_SubgroupMmaComputeOp
1933 //===----------------------------------------------------------------------===//
1934 
1935 LogicalResult SubgroupMmaComputeOp::verify() {
1936  enum OperandMap { A, B, C };
1937  SmallVector<MMAMatrixType, 3> opTypes;
1938  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1939  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1940  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1941 
1942  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1943  opTypes[C].getOperand() != "COp")
1944  return emitError("operands must be in the order AOp, BOp, COp");
1945 
1946  ArrayRef<int64_t> aShape, bShape, cShape;
1947  aShape = opTypes[A].getShape();
1948  bShape = opTypes[B].getShape();
1949  cShape = opTypes[C].getShape();
1950 
1951  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1952  bShape[1] != cShape[1])
1953  return emitError("operand shapes do not satisfy matmul constraints");
1954 
1955  return success();
1956 }
1957 
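// For illustration, a compute op satisfying the checks above: with A and B
// of shape 16x16 ("AOp"/"BOp") and C of shape 16x16 ("COp"),
// aShape[1] == bShape[0], aShape[0] == cShape[0], and bShape[1] == cShape[1]
// all hold (a sketch; the sizes actually supported are target-dependent):
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c :
//          !gpu.mma_matrix<16x16xf16, "AOp">,
//          !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">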
1958 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1959  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1960  return memref::foldMemRefCast(*this);
1961 }
1962 
1963 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1964  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1965  return memref::foldMemRefCast(*this);
1966 }
1967 
1968 //===----------------------------------------------------------------------===//
1969 // GPU_WaitOp
1970 //===----------------------------------------------------------------------===//
1971 
1972 namespace {
1973 
1974 /// Remove a gpu.wait op's use of a gpu.wait def that has no async dependencies.
1975 /// %t = gpu.wait async [] // No async dependencies.
1976 /// ... gpu.wait ... [%t, ...] // %t can be removed.
1977 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1978 public:
1979  using OpRewritePattern::OpRewritePattern;
1980 
1981  LogicalResult matchAndRewrite(WaitOp op,
1982  PatternRewriter &rewriter) const final {
1983  auto predicate = [](Value value) {
1984  auto waitOp = value.getDefiningOp<WaitOp>();
1985  return waitOp && waitOp->getNumOperands() == 0;
1986  };
1987  if (llvm::none_of(op.getAsyncDependencies(), predicate))
1988  return failure();
1989  SmallVector<Value> validOperands;
1990  for (Value operand : op->getOperands()) {
1991  if (predicate(operand))
1992  continue;
1993  validOperands.push_back(operand);
1994  }
1995  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
1996  return success();
1997  }
1998 };
1999 
2000 /// Simplify trivial gpu.wait ops for the following patterns.
2001 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2002 /// dependencies).
2003 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2004 /// %t0.
2005 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2006 /// dependencies nor return any token.
2007 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2008 public:
2009  using OpRewritePattern::OpRewritePattern;
2010 
2011  LogicalResult matchAndRewrite(WaitOp op,
2012  PatternRewriter &rewriter) const final {
2013  // Erase gpu.wait ops that neither have any async dependencies nor return
2014  // any async token.
2015  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2016  rewriter.eraseOp(op);
2017  return success();
2018  }
2019  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2020  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2021  op.getAsyncToken()) {
2022  rewriter.replaceOp(op, op.getAsyncDependencies());
2023  return success();
2024  }
2025  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2026  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2027  rewriter.eraseOp(op);
2028  return success();
2029  }
2030  return failure();
2031  }
2032 };
2033 
2034 } // end anonymous namespace
2035 
2036 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2037  MLIRContext *context) {
2038  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2039 }
2040 
2041 //===----------------------------------------------------------------------===//
2042 // GPU_AllocOp
2043 //===----------------------------------------------------------------------===//
2044 
2045 LogicalResult AllocOp::verify() {
2046  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2047 
2048  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2049  memRefType.getNumDynamicDims())
2050  return emitOpError("dimension operand count does not equal memref "
2051  "dynamic dimension count");
2052 
2053  unsigned numSymbols = 0;
2054  if (!memRefType.getLayout().isIdentity())
2055  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2056  if (getSymbolOperands().size() != numSymbols) {
2057  return emitOpError(
2058  "symbol operand count does not equal memref symbol count");
2059  }
2060 
2061  return success();
2062 }
2063 
2064 namespace {
2065 
2066 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2067 /// `memref::AllocOp`.
2068 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2069  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2070 
2071  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2072  PatternRewriter &rewriter) const override {
2073  std::optional<int64_t> index = dimOp.getConstantIndex();
2074  if (!index)
2075  return failure();
2076 
2077  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2078  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2079  return failure();
2080 
2081  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2082  if (!alloc)
2083  return failure();
2084 
2085  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2086  memrefType.getDynamicDimIndex(index.value()));
2087  rewriter.replaceOp(dimOp, substituteOp);
2088  return success();
2089  }
2090 };
2091 
2092 } // namespace
2093 
2094 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095  MLIRContext *context) {
2096  results.add<SimplifyDimOfAllocOp>(context);
2097 }
2098 
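// For illustration, the fold performed by SimplifyDimOfAllocOp (a sketch;
// %size supplies the single dynamic dimension of the allocation):
//
//   %mem = gpu.alloc (%size) : memref<?xf32>
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %mem, %c0 : memref<?xf32>  // rewritten to use %size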
2099 //===----------------------------------------------------------------------===//
2100 // GPU object attribute
2101 //===----------------------------------------------------------------------===//
2102 
2103 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2104  Attribute target, CompilationTarget format,
2105  StringAttr object, DictionaryAttr properties,
2106  KernelTableAttr kernels) {
2107  if (!target)
2108  return emitError() << "the target attribute cannot be null";
2109  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2110  return success();
2111  return emitError() << "the target attribute must implement or promise the "
2112  "`gpu::TargetAttrInterface`";
2113 }
2114 
2115 namespace {
2116 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2117  StringAttr &object) {
2118  std::optional<CompilationTarget> formatResult;
2119  StringRef enumKeyword;
2120  auto loc = odsParser.getCurrentLocation();
2121  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2122  formatResult = CompilationTarget::Fatbin;
2123  if (!formatResult &&
2124  (formatResult =
2125  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2126  odsParser.parseEqual())
2127  return odsParser.emitError(loc, "expected an equal sign");
2128  if (!formatResult)
2129  return odsParser.emitError(loc, "expected keyword for GPU object format");
2130  FailureOr<StringAttr> objectResult =
2131  FieldParser<StringAttr>::parse(odsParser);
2132  if (failed(objectResult))
2133  return odsParser.emitError(odsParser.getCurrentLocation(),
2134  "failed to parse GPU_ObjectAttr parameter "
2135  "'object' which is to be a `StringAttr`");
2136  format = *formatResult;
2137  object = *objectResult;
2138  return success();
2139 }
2140 
2141 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2142  StringAttr object) {
2143  if (format != CompilationTarget::Fatbin)
2144  odsParser << stringifyEnum(format) << " = ";
2145  odsParser << object;
2146 }
2147 } // namespace
2148 
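// For illustration, the #gpu.object forms accepted above; the format
// keyword is omitted for the default Fatbin format (a sketch with
// placeholder contents):
//
//   #gpu.object<#nvvm.target, "fatbin contents">
//   #gpu.object<#nvvm.target, assembly = "ptx assembly">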
2149 //===----------------------------------------------------------------------===//
2150 // GPU select object attribute
2151 //===----------------------------------------------------------------------===//
2152 
2153 LogicalResult
2154 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2155  Attribute target) {
2156  // Check `target`: it can be null, an integer attr, or a GPU target attribute.
2157  if (target) {
2158  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2159  if (intAttr.getInt() < 0) {
2160  return emitError() << "the object index must be positive";
2161  }
2162  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2163  return emitError()
2164  << "the target attribute must be a GPU Target attribute";
2165  }
2166  }
2167  return success();
2168 }
2169 
2170 //===----------------------------------------------------------------------===//
2171 // DynamicSharedMemoryOp
2172 //===----------------------------------------------------------------------===//
2173 
2174 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2175  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2176  return emitOpError() << "must be inside an op with symbol table";
2177 
2178  MemRefType memrefType = getResultMemref().getType();
2179  // Check address space
2180  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2181  return emitOpError() << "address space must be "
2182  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2183  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2184  }
2185  if (memrefType.hasStaticShape()) {
2186  return emitOpError() << "result memref type must be memref<?xi8, "
2187  "#gpu.address_space<workgroup>>";
2188  }
2189  return success();
2190 }
2191 
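// For illustration, a use that satisfies both checks above: the result is a
// dynamically shaped i8 memref in the workgroup address space.
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>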
2192 //===----------------------------------------------------------------------===//
2193 // GPU KernelMetadataAttr
2194 //===----------------------------------------------------------------------===//
2195 
2196 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2197  DictionaryAttr metadata) {
2198  assert(kernel && "invalid kernel");
2199  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2200  kernel.getAllArgAttrs(), metadata);
2201 }
2202 
2203 KernelMetadataAttr
2204 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2205  FunctionOpInterface kernel,
2206  DictionaryAttr metadata) {
2207  assert(kernel && "invalid kernel");
2208  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2209  kernel.getAllArgAttrs(), metadata);
2210 }
2211 
2212 KernelMetadataAttr
2213 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2214  if (attrs.empty())
2215  return *this;
2216  NamedAttrList attrList;
2217  if (DictionaryAttr dict = getMetadata())
2218  attrList.append(dict);
2219  attrList.append(attrs);
2220  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2221  attrList.getDictionary(getContext()));
2222 }
2223 
2224 LogicalResult
2225 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2226  StringAttr name, Type functionType,
2227  ArrayAttr argAttrs, DictionaryAttr metadata) {
2228  if (name.empty())
2229  return emitError() << "the kernel name can't be empty";
2230  if (argAttrs) {
2231  if (llvm::any_of(argAttrs, [](Attribute attr) {
2232  return !llvm::isa<DictionaryAttr>(attr);
2233  }))
2234  return emitError()
2235  << "all attributes in the array must be a dictionary attribute";
2236  }
2237  return success();
2238 }
2239 
2240 //===----------------------------------------------------------------------===//
2241 // GPU KernelTableAttr
2242 //===----------------------------------------------------------------------===//
2243 
2244 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2245  ArrayRef<KernelMetadataAttr> kernels,
2246  bool isSorted) {
2247  // Note that `llvm::is_sorted` is invoked at most once, even with assertions on.
2248  assert((!isSorted || llvm::is_sorted(kernels)) &&
2249  "expected a sorted kernel array");
2250  // Immediately return the attribute if the array is sorted.
2251  if (isSorted || llvm::is_sorted(kernels))
2252  return Base::get(context, kernels);
2253  // Sort the array.
2254  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2255  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2256  return Base::get(context, kernelsTmp);
2257 }
2258 
2259 KernelTableAttr KernelTableAttr::getChecked(
2260  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2261  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2262  // Note that `llvm::is_sorted` is invoked at most once, even with assertions on.
2263  assert((!isSorted || llvm::is_sorted(kernels)) &&
2264  "expected a sorted kernel array");
2265  // Immediately return the attribute if the array is sorted.
2266  if (isSorted || llvm::is_sorted(kernels))
2267  return Base::getChecked(emitError, context, kernels);
2268  // Sort the array.
2269  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2270  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2271  return Base::getChecked(emitError, context, kernelsTmp);
2272 }
2273 
2274 LogicalResult
2275 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2276  ArrayRef<KernelMetadataAttr> kernels) {
2277  if (kernels.size() < 2)
2278  return success();
2279  // Check that the kernels are uniquely named.
2280  if (std::adjacent_find(kernels.begin(), kernels.end(),
2281  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2282  return l.getName() == r.getName();
2283  }) != kernels.end()) {
2284  return emitError() << "expected all kernels to be uniquely named";
2285  }
2286  return success();
2287 }
2288 
2289 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2290  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2291  return found ? *iterator : KernelMetadataAttr();
2292 }
2293 
2294 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2295  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2296  return found ? *iterator : KernelMetadataAttr();
2297 }
2298 
2299 //===----------------------------------------------------------------------===//
2300 // GPU target options
2301 //===----------------------------------------------------------------------===//
2302 
2303 TargetOptions::TargetOptions(
2304  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2305  StringRef cmdOptions, CompilationTarget compilationTarget,
2306  function_ref<SymbolTable *()> getSymbolTableCallback)
2307  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2308  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2309 
2310 TargetOptions::TargetOptions(
2311  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2312  StringRef cmdOptions, CompilationTarget compilationTarget,
2313  function_ref<SymbolTable *()> getSymbolTableCallback)
2314  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2315  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2316  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2317 
2318 TypeID TargetOptions::getTypeID() const { return typeID; }
2319 
2320 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2321 
2322 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2323 
2324 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2325 
2326 SymbolTable *TargetOptions::getSymbolTable() const {
2327  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2328 }
2329 
2330 CompilationTarget TargetOptions::getCompilationTarget() const {
2331  return compilationTarget;
2332 }
2333 
2334 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2335  return CompilationTarget::Fatbin;
2336 }
2337 
2338 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2339 TargetOptions::tokenizeCmdOptions() const {
2340  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2341  llvm::StringSaver stringSaver(options.first);
2342  StringRef opts = cmdOptions;
2343  // For correct tokenization, the command line options in `opts` must be
2344  // unquoted; otherwise the tokenizer returns the whole quoted string as a
2345  // single token, which is not the desired behavior.
2346  // Remove any quotes if they are at the beginning and end of the string:
2347  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2348  opts.consume_front("\""), opts.consume_back("\"");
2349  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2350  opts.consume_front("'"), opts.consume_back("'");
2351 #ifdef _WIN32
2352  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2353  /*MarkEOLs=*/false);
2354 #else
2355  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2356  /*MarkEOLs=*/false);
2357 #endif // _WIN32
2358  return options;
2359 }
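// For illustration (a sketch of the behavior described above): with
// cmdOptions set to the quoted string "-O3 --use_fast_math", the quotes are
// stripped first and tokenization yields {"-O3", "--use_fast_math"}; without
// the stripping, the whole quoted string would come back as a single token.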
2360 
2361 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2362 
2363 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2364 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2365 
2366 #define GET_ATTRDEF_CLASSES
2367 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2368 
2369 #define GET_OP_CLASSES
2370 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2371 
2372 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"