1 //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU kernel-related dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/IR/TypeUtilities.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/StringSaver.h"
39 #include <cassert>
40 #include <numeric>
41 
42 using namespace mlir;
43 using namespace mlir::gpu;
44 
45 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // GPU Device Mapping Attributes
49 //===----------------------------------------------------------------------===//
50 
51 int64_t GPUBlockMappingAttr::getMappingId() const {
52  return static_cast<int64_t>(getBlock());
53 }
54 
55 bool GPUBlockMappingAttr::isLinearMapping() const {
56  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
57 }
58 
59 int64_t GPUBlockMappingAttr::getRelativeIndex() const {
60  return isLinearMapping()
61  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
62  : getMappingId();
63 }
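// Illustrative example (assuming the MappingId enum lists DimX..DimZ before
// the LinearDim values): #gpu.block<y> is a non-linear mapping whose relative
// index equals its mapping id, while #gpu.block<linear_dim_2> is a linear
// mapping whose relative index is its id minus LinearDim0, i.e. 2.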
64 
65 int64_t GPUWarpgroupMappingAttr::getMappingId() const {
66  return static_cast<int64_t>(getWarpgroup());
67 }
68 
69 bool GPUWarpgroupMappingAttr::isLinearMapping() const {
70  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
71 }
72 
73 int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
74  return isLinearMapping()
75  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
76  : getMappingId();
77 }
78 
79 int64_t GPUWarpMappingAttr::getMappingId() const {
80  return static_cast<int64_t>(getWarp());
81 }
82 
83 bool GPUWarpMappingAttr::isLinearMapping() const {
84  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
85 }
86 
87 int64_t GPUWarpMappingAttr::getRelativeIndex() const {
88  return isLinearMapping()
89  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
90  : getMappingId();
91 }
92 
93 int64_t GPUThreadMappingAttr::getMappingId() const {
94  return static_cast<int64_t>(getThread());
95 }
96 
97 bool GPUThreadMappingAttr::isLinearMapping() const {
98  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
99 }
100 
101 int64_t GPUThreadMappingAttr::getRelativeIndex() const {
102  return isLinearMapping()
103  ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
104  : getMappingId();
105 }
106 
107 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
108  return static_cast<int64_t>(getAddressSpace());
109 }
110 
111 bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
112  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
113 }
114 
115 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
116  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // MMAMatrixType
121 //===----------------------------------------------------------------------===//
122 
123 MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
124  StringRef operand) {
125  return Base::get(elementType.getContext(), shape, elementType, operand);
126 }
127 
128 MMAMatrixType
129 MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
130  ArrayRef<int64_t> shape, Type elementType,
131  StringRef operand) {
132  return Base::getChecked(emitError, elementType.getContext(), shape,
133  elementType, operand);
134 }
135 
136 unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
137 
138 ArrayRef<int64_t> MMAMatrixType::getShape() const {
139  return getImpl()->getShape();
140 }
141 
142 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
143 
144 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
145 
146 bool MMAMatrixType::isValidElementType(Type elementType) {
147  return elementType.isF16() || elementType.isF32() ||
148  elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
149  elementType.isInteger(32);
150 }
151 
152 LogicalResult
153 MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
154  ArrayRef<int64_t> shape, Type elementType,
155  StringRef operand) {
156  if (operand != "AOp" && operand != "BOp" && operand != "COp")
157  return emitError() << "operand expected to be one of AOp, BOp or COp";
158 
159  if (shape.size() != 2)
160  return emitError() << "MMAMatrixType must have exactly two dimensions";
161 
162  if (!MMAMatrixType::isValidElementType(elementType))
163  return emitError()
164  << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
165 
166  return success();
167 }
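// For instance, !gpu.mma_matrix<16x16xf16, "AOp"> satisfies these invariants,
// whereas a one-dimensional shape, an f64 element type, or an operand string
// other than AOp/BOp/COp is rejected by the checks above.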
168 
169 //===----------------------------------------------------------------------===//
170 // GPUDialect
171 //===----------------------------------------------------------------------===//
172 
173 bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
174  if (!memorySpace)
175  return false;
176  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
177  return gpuAttr.getValue() == getWorkgroupAddressSpace();
178  return false;
179 }
180 
181 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
182  Attribute memorySpace = type.getMemorySpace();
183  return isWorkgroupMemoryAddressSpace(memorySpace);
184 }
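// For example, a memref typed as
//   memref<32xf32, #gpu.address_space<workgroup>>
// is recognized here, while a memref with no memory space, or one whose space
// has already been lowered to a target-specific integer, is not.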
185 
186 bool GPUDialect::isKernel(Operation *op) {
187  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
188  return static_cast<bool>(isKernelAttr);
189 }
190 
191 namespace {
192 /// This class defines the interface for handling inlining with gpu
193 /// operations.
194 struct GPUInlinerInterface : public DialectInlinerInterface {
195  using DialectInlinerInterface::DialectInlinerInterface;
196 
197  /// All gpu dialect ops can be inlined.
198  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
199  return true;
200  }
201 };
202 } // namespace
203 
204 void GPUDialect::initialize() {
205  addTypes<AsyncTokenType>();
206  addTypes<MMAMatrixType>();
207  addTypes<SparseDnTensorHandleType>();
208  addTypes<SparseSpMatHandleType>();
209  addTypes<SparseSpGEMMOpHandleType>();
210  addOperations<
211 #define GET_OP_LIST
212 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
213  >();
214  addAttributes<
215 #define GET_ATTRDEF_LIST
216 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
217  >();
218  addInterfaces<GPUInlinerInterface>();
219  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
220  TerminatorOp>();
221  declarePromisedInterfaces<
222  ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
223  ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
224  SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
225 }
226 
227 static StringRef getSparseHandleKeyword(SparseHandleKind kind) {
228  switch (kind) {
229  case SparseHandleKind::DnTensor:
230  return "sparse.dntensor_handle";
231  case SparseHandleKind::SpMat:
232  return "sparse.spmat_handle";
233  case SparseHandleKind::SpGEMMOp:
234  return "sparse.spgemmop_handle";
235  }
236  llvm_unreachable("unknown sparse handle kind");
237  return "";
238 }
239 
240 Type GPUDialect::parseType(DialectAsmParser &parser) const {
241  // Parse the main keyword for the type.
242  StringRef keyword;
243  if (parser.parseKeyword(&keyword))
244  return Type();
245  MLIRContext *context = getContext();
246 
247  // Handle 'async token' types.
248  if (keyword == "async.token")
249  return AsyncTokenType::get(context);
250 
251  if (keyword == "mma_matrix") {
252  SMLoc beginLoc = parser.getNameLoc();
253 
254  // Parse '<'.
255  if (parser.parseLess())
256  return nullptr;
257 
258  // Parse the size and elementType.
259  SmallVector<int64_t> shape;
260  Type elementType;
261  if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
262  parser.parseType(elementType))
263  return nullptr;
264 
265  // Parse ','
266  if (parser.parseComma())
267  return nullptr;
268 
269  // Parse operand.
270  std::string operand;
271  if (failed(parser.parseOptionalString(&operand)))
272  return nullptr;
273 
274  // Parse '>'.
275  if (parser.parseGreater())
276  return nullptr;
277 
278  return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
279  parser.getEncodedSourceLoc(beginLoc)),
280  shape, elementType, operand);
281  }
282 
283  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
284  return SparseDnTensorHandleType::get(context);
285  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
286  return SparseSpMatHandleType::get(context);
287  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
288  return SparseSpGEMMOpHandleType::get(context);
289 
290  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
291  return Type();
292 }
293 // TODO: Print a more refined type here. Note that it should correspond to the
294 // parser.
295 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
296  TypeSwitch<Type>(type)
297  .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
298  .Case<SparseDnTensorHandleType>([&](Type) {
299  os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
300  })
301  .Case<SparseSpMatHandleType>(
302  [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
303  .Case<SparseSpGEMMOpHandleType>([&](Type) {
304  os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
305  })
306  .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
307  os << "mma_matrix<";
308  auto shape = fragTy.getShape();
309  for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
310  os << *dim << 'x';
311  os << shape.back() << 'x' << fragTy.getElementType();
312  os << ", \"" << fragTy.getOperand() << "\"" << '>';
313  })
314  .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
315 }
316 
317 static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
318  NamedAttribute attr) {
319  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
320  if (!array)
321  return op->emitOpError(Twine(attr.getName()) +
322  " must be a dense i32 array");
323  if (array.size() != 3)
324  return op->emitOpError(Twine(attr.getName()) +
325  " must contain exactly 3 elements");
326  return success();
327 }
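// Illustrative example: an attribute such as
//   gpu.known_block_size = array<i32: 128, 1, 1>
// passes this check, while a two-element array or a non-i32 array attribute
// is diagnosed.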
328 
329 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
330  NamedAttribute attr) {
331  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
332  return verifyKnownLaunchSizeAttr(op, attr);
333  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
334  return verifyKnownLaunchSizeAttr(op, attr);
335  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
336  attr.getName() != getContainerModuleAttrName())
337  return success();
338 
339  auto module = dyn_cast<ModuleOp>(op);
340  if (!module)
341  return op->emitError("expected '")
342  << getContainerModuleAttrName() << "' attribute to be attached to '"
343  << ModuleOp::getOperationName() << '\'';
344 
345  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
346  // Ignore launches that are nested more or less deeply than functions in
347  // the module we are currently checking.
348  if (!launchOp->getParentOp() ||
349  launchOp->getParentOp()->getParentOp() != module)
350  return success();
351 
352  // Ignore launch ops with missing attributes here. The errors will be
353  // reported by the verifiers of those ops.
354  if (!launchOp->getAttrOfType<SymbolRefAttr>(
355  LaunchFuncOp::getKernelAttrName(launchOp->getName())))
356  return success();
357 
358  // Check that `launch_func` refers to a well-formed GPU kernel container.
359  StringAttr kernelContainerName = launchOp.getKernelModuleName();
360  Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
361  if (!kernelContainer)
362  return launchOp.emitOpError()
363  << "kernel container '" << kernelContainerName.getValue()
364  << "' is undefined";
365 
366  // If the container is a GPU binary op return success.
367  if (isa<BinaryOp>(kernelContainer))
368  return success();
369 
370  auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
371  if (!kernelModule)
372  return launchOp.emitOpError()
373  << "kernel module '" << kernelContainerName.getValue()
374  << "' is undefined";
375 
376  // Check that `launch_func` refers to a well-formed kernel function.
377  Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
378  if (!kernelFunc)
379  return launchOp.emitOpError("kernel function '")
380  << launchOp.getKernel() << "' is undefined";
381  auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
382  if (!kernelConvertedFunction) {
383  InFlightDiagnostic diag = launchOp.emitOpError()
384  << "referenced kernel '" << launchOp.getKernel()
385  << "' is not a function";
386  diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
387  return diag;
388  }
389 
390  if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
391  GPUDialect::getKernelFuncAttrName()))
392  return launchOp.emitOpError("kernel function is missing the '")
393  << GPUDialect::getKernelFuncAttrName() << "' attribute";
394 
395  // TODO: If the kernel isn't a GPU function (which happens during separate
396  // compilation), do not check type correspondence as it would require the
397  // verifier to be aware of the type conversion.
398  auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
399  if (!kernelGPUFunction)
400  return success();
401 
402  unsigned actualNumArguments = launchOp.getNumKernelOperands();
403  unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
404  if (expectedNumArguments != actualNumArguments)
405  return launchOp.emitOpError("got ")
406  << actualNumArguments << " kernel operands but expected "
407  << expectedNumArguments;
408 
409  auto functionType = kernelGPUFunction.getFunctionType();
410  for (unsigned i = 0; i < expectedNumArguments; ++i) {
411  if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
412  return launchOp.emitOpError("type of function argument ")
413  << i << " does not match";
414  }
415  }
416 
417  return success();
418  });
419 
420  return walkResult.wasInterrupted() ? failure() : success();
421 }
422 
423 /// Parses an optional list of async operands with an optional leading keyword.
424 /// (`async`)? (`[` ssa-id-list `]`)?
425 ///
426 /// This method is used by the tablegen assembly format for async ops as well.
427 static ParseResult parseAsyncDependencies(
428  OpAsmParser &parser, Type &asyncTokenType,
429  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
430  auto loc = parser.getCurrentLocation();
431  if (succeeded(parser.parseOptionalKeyword("async"))) {
432  if (parser.getNumResults() == 0)
433  return parser.emitError(loc, "needs to be named when marked 'async'");
434  asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
435  }
436  return parser.parseOperandList(asyncDependencies,
437  OpAsmParser::Delimiter::OptionalSquare);
438 }
439 
440 /// Prints optional async dependencies with its leading keyword.
441 /// (`async`)? (`[` ssa-id-list `]`)?
442 // Used by the tablegen assembly format for several async ops.
443 static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
444  Type asyncTokenType,
445  OperandRange asyncDependencies) {
446  if (asyncTokenType)
447  printer << "async";
448  if (asyncDependencies.empty())
449  return;
450  if (asyncTokenType)
451  printer << ' ';
452  printer << '[';
453  llvm::interleaveComma(asyncDependencies, printer);
454  printer << ']';
455 }
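// Illustrative syntax (hypothetical SSA names): an async op with dependencies
// prints as
//   %token = gpu.wait async [%dep0, %dep1]
// while a synchronous op omits both the `async` keyword and the bracket list.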
456 
457 // GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
458 /// Parses a GPU function memory attribution.
459 ///
460 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
461 /// (`private` `(` ssa-id-and-type-list `)`)?
462 ///
463 /// Note that this function parses only one of the two similar parts, with the
464 /// keyword provided as argument.
465 static ParseResult
466 parseAttributions(OpAsmParser &parser, StringRef keyword,
467  SmallVectorImpl<OpAsmParser::Argument> &args) {
468  // If we could not parse the keyword, just assume empty list and succeed.
469  if (failed(parser.parseOptionalKeyword(keyword)))
470  return success();
471 
472  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
473  /*allowType=*/true);
474 }
475 
476 /// Prints a GPU function memory attribution.
477 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
478  ArrayRef<BlockArgument> values) {
479  if (values.empty())
480  return;
481 
482  p << ' ' << keyword << '(';
483  llvm::interleaveComma(
484  values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
485  p << ')';
486 }
487 
488 /// Verifies a GPU function memory attribution.
489 static LogicalResult verifyAttributions(Operation *op,
490  ArrayRef<BlockArgument> attributions,
491  gpu::AddressSpace memorySpace) {
492  for (Value v : attributions) {
493  auto type = llvm::dyn_cast<MemRefType>(v.getType());
494  if (!type)
495  return op->emitOpError() << "expected memref type in attribution";
496 
497  // We can only verify the address space if it hasn't already been lowered
498  // from the AddressSpaceAttr to a target-specific numeric value.
499  auto addressSpace =
500  llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
501  if (!addressSpace)
502  continue;
503  if (addressSpace.getValue() != memorySpace)
504  return op->emitOpError()
505  << "expected memory space " << stringifyAddressSpace(memorySpace)
506  << " in attribution";
507  }
508  return success();
509 }
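// Illustrative example (hypothetical buffer name): an attribution such as
//   workgroup(%tile : memref<32xf32, #gpu.address_space<workgroup>>)
// passes this check, whereas declaring the same buffer with
// #gpu.address_space<private> under the `workgroup` keyword is diagnosed.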
510 
511 //===----------------------------------------------------------------------===//
512 // AllReduceOp
513 //===----------------------------------------------------------------------===//
514 
515 static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
516  Type resType) {
517  using Kind = gpu::AllReduceOperation;
518  if (llvm::is_contained(
519  {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
520  opName)) {
521  if (!isa<FloatType>(resType))
522  return failure();
523  }
524 
525  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
526  Kind::AND, Kind::OR, Kind::XOR},
527  opName)) {
528  if (!isa<IntegerType>(resType))
529  return failure();
530  }
531 
532  return success();
533 }
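// For example, `add` and `mul` are accepted for any supported element type,
// the floating-point kinds (e.g. minnumf) require a float type, and the
// integer kinds (e.g. minsi, and, xor) require an integer type; a `minsi`
// reduction over an f32 value is therefore rejected.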
534 
535 LogicalResult gpu::AllReduceOp::verifyRegions() {
536  if (getBody().empty() != getOp().has_value())
537  return emitError("expected either an op attribute or a non-empty body");
538  if (!getBody().empty()) {
539  if (getBody().getNumArguments() != 2)
540  return emitError("expected two region arguments");
541  for (auto argument : getBody().getArguments()) {
542  if (argument.getType() != getType())
543  return emitError("incorrect region argument type");
544  }
545  unsigned yieldCount = 0;
546  for (Block &block : getBody()) {
547  if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
548  if (yield.getNumOperands() != 1)
549  return emitError("expected one gpu.yield operand");
550  if (yield.getOperand(0).getType() != getType())
551  return emitError("incorrect gpu.yield type");
552  ++yieldCount;
553  }
554  }
555  if (yieldCount == 0)
556  return emitError("expected gpu.yield op in region");
557  } else {
558  gpu::AllReduceOperation opName = *getOp();
559  if (failed(verifyReduceOpAndType(opName, getType()))) {
560  return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
561  << "` reduction operation is not compatible with type "
562  << getType();
563  }
564  }
565 
566  return success();
567 }
568 
569 static bool canMakeGroupOpUniform(Operation *op) {
570  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
571  if (!launchOp)
572  return false;
573 
574  Region &body = launchOp.getBody();
575  assert(!body.empty() && "Invalid region");
576 
577  // Only convert ops in gpu::launch entry block for now.
578  return op->getBlock() == &body.front();
579 }
580 
581 OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
582  if (!getUniform() && canMakeGroupOpUniform(*this)) {
583  setUniform(true);
584  return getResult();
585  }
586 
587  return nullptr;
588 }
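// For example, an all_reduce sitting directly in the entry block of its
// enclosing gpu.launch is executed by every thread of the launch, so the fold
// above can mark it `uniform` without changing semantics.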
589 
590 // TODO: Support optional custom attributes (without dialect prefix).
591 static ParseResult parseAllReduceOperation(AsmParser &parser,
592  AllReduceOperationAttr &attr) {
593  StringRef enumStr;
594  if (!parser.parseOptionalKeyword(&enumStr)) {
595  std::optional<AllReduceOperation> op =
596  gpu::symbolizeAllReduceOperation(enumStr);
597  if (!op)
598  return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
599  attr = AllReduceOperationAttr::get(parser.getContext(), *op);
600  }
601  return success();
602 }
603 
604 static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
605  AllReduceOperationAttr attr) {
606  if (attr)
607  attr.print(printer);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // SubgroupReduceOp
612 //===----------------------------------------------------------------------===//
613 
614 LogicalResult gpu::SubgroupReduceOp::verify() {
615  Type elemType = getType();
616  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
617  if (vecTy.isScalable())
618  return emitOpError() << "is not compatible with scalable vector types";
619 
620  elemType = vecTy.getElementType();
621  }
622 
623  gpu::AllReduceOperation opName = getOp();
624  if (failed(verifyReduceOpAndType(opName, elemType))) {
625  return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
626  << "` reduction operation is not compatible with type "
627  << getType();
628  }
629 
630  auto clusterSize = getClusterSize();
631  if (clusterSize) {
632  uint32_t size = *clusterSize;
633  if (!llvm::isPowerOf2_32(size)) {
634  return emitOpError() << "cluster size " << size
635  << " is not a power of two";
636  }
637  }
638 
639  uint32_t stride = getClusterStride();
640  if (stride != 1 && !clusterSize) {
641  return emitOpError() << "cluster stride can only be specified if cluster "
642  "size is specified";
643  }
644  if (!llvm::isPowerOf2_32(stride)) {
645  return emitOpError() << "cluster stride " << stride
646  << " is not a power of two";
647  }
648 
649  return success();
650 }
651 
652 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
653  if (getClusterSize() == 1)
654  return getValue();
655 
656  if (!getUniform() && canMakeGroupOpUniform(*this)) {
657  setUniform(true);
658  return getResult();
659  }
660 
661  return nullptr;
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // AsyncOpInterface
666 //===----------------------------------------------------------------------===//
667 
668 void gpu::addAsyncDependency(Operation *op, Value token) {
669  op->insertOperands(0, {token});
670  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
671  return;
672  auto attrName =
673  OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
674  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
675 
676  // Async dependencies is the only variadic operand.
677  if (!sizeAttr)
678  return;
679 
680  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
681  ++sizes.front();
682  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
683 }
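// Illustrative use (hypothetical segment values): for an op whose operand
// segment sizes are [2, 1, 1] with the async dependencies in the first
// segment, gpu::addAsyncDependency(op, token) inserts `token` at position 0
// and rewrites the segment sizes to [3, 1, 1].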
684 
685 //===----------------------------------------------------------------------===//
686 // LaunchOp
687 //===----------------------------------------------------------------------===//
688 
689 void LaunchOp::build(OpBuilder &builder, OperationState &result,
690  Value gridSizeX, Value gridSizeY, Value gridSizeZ,
691  Value getBlockSizeX, Value getBlockSizeY,
692  Value getBlockSizeZ, Value dynamicSharedMemorySize,
693  Type asyncTokenType, ValueRange asyncDependencies,
694  TypeRange workgroupAttributions,
695  TypeRange privateAttributions, Value clusterSizeX,
696  Value clusterSizeY, Value clusterSizeZ) {
697  OpBuilder::InsertionGuard g(builder);
698 
699  // Add a WorkGroup attribution attribute. This attribute is required to
700  // identify private attributions in the list of block arguments.
701  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
702  builder.getI64IntegerAttr(workgroupAttributions.size()));
703 
704  // Add Op operands.
705  result.addOperands(asyncDependencies);
706  if (asyncTokenType)
707  result.types.push_back(builder.getType<AsyncTokenType>());
708 
709  // Add grid and block sizes as op operands, followed by the data operands.
710  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
711  getBlockSizeY, getBlockSizeZ});
712  if (clusterSizeX)
713  result.addOperands(clusterSizeX);
714  if (clusterSizeY)
715  result.addOperands(clusterSizeY);
716  if (clusterSizeZ)
717  result.addOperands(clusterSizeZ);
718  if (dynamicSharedMemorySize)
719  result.addOperands(dynamicSharedMemorySize);
720 
721  // Create a kernel body region with kNumConfigRegionAttributes + N memory
722  // attributions, where the first kNumConfigRegionAttributes arguments have
723  // `index` type and the rest have the same types as the data operands.
724  Region *kernelRegion = result.addRegion();
725  Block *body = builder.createBlock(kernelRegion);
726  // TODO: Allow passing in proper locations here.
727  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
728  body->addArgument(builder.getIndexType(), result.location);
729  // Add WorkGroup & Private attributions to the region arguments.
730  for (Type argTy : workgroupAttributions)
731  body->addArgument(argTy, result.location);
732  for (Type argTy : privateAttributions)
733  body->addArgument(argTy, result.location);
734  // Fill OperandSegmentSize Attribute.
735  SmallVector<int32_t, 11> segmentSizes(11, 1);
736  segmentSizes.front() = asyncDependencies.size();
737  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
738  segmentSizes[7] = clusterSizeX ? 1 : 0;
739  segmentSizes[8] = clusterSizeY ? 1 : 0;
740  segmentSizes[9] = clusterSizeZ ? 1 : 0;
741  result.addAttribute(getOperandSegmentSizeAttr(),
742  builder.getDenseI32ArrayAttr(segmentSizes));
743 }
744 
745 KernelDim3 LaunchOp::getBlockIds() {
746  assert(!getBody().empty() && "LaunchOp body must not be empty.");
747  auto args = getBody().getArguments();
748  return KernelDim3{args[0], args[1], args[2]};
749 }
750 
751 KernelDim3 LaunchOp::getThreadIds() {
752  assert(!getBody().empty() && "LaunchOp body must not be empty.");
753  auto args = getBody().getArguments();
754  return KernelDim3{args[3], args[4], args[5]};
755 }
756 
757 KernelDim3 LaunchOp::getGridSize() {
758  assert(!getBody().empty() && "LaunchOp body must not be empty.");
759  auto args = getBody().getArguments();
760  return KernelDim3{args[6], args[7], args[8]};
761 }
762 
763 KernelDim3 LaunchOp::getBlockSize() {
764  assert(!getBody().empty() && "LaunchOp body must not be empty.");
765  auto args = getBody().getArguments();
766  return KernelDim3{args[9], args[10], args[11]};
767 }
768 
769 std::optional<KernelDim3> LaunchOp::getClusterIds() {
770  assert(!getBody().empty() && "LaunchOp body must not be empty.");
771  if (!hasClusterSize())
772  return std::nullopt;
773  auto args = getBody().getArguments();
774  return KernelDim3{args[12], args[13], args[14]};
775 }
776 
777 std::optional<KernelDim3> LaunchOp::getClusterSize() {
778  assert(!getBody().empty() && "LaunchOp body must not be empty.");
779  if (!hasClusterSize())
780  return std::nullopt;
781  auto args = getBody().getArguments();
782  return KernelDim3{args[15], args[16], args[17]};
783 }
784 
785 KernelDim3 LaunchOp::getGridSizeOperandValues() {
786  auto operands = getOperands().drop_front(getAsyncDependencies().size());
787  return KernelDim3{operands[0], operands[1], operands[2]};
788 }
789 
790 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
791  auto operands = getOperands().drop_front(getAsyncDependencies().size());
792  return KernelDim3{operands[3], operands[4], operands[5]};
793 }
794 
795 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
796  auto operands = getOperands().drop_front(getAsyncDependencies().size());
797  if (!hasClusterSize())
798  return std::nullopt;
799  return KernelDim3{operands[6], operands[7], operands[8]};
800 }
801 
802 LogicalResult LaunchOp::verify() {
803  if (!(hasClusterSize()) &&
804  (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
805  return emitOpError() << "cluster size must be all present";
806  return success();
807 }
808 
809 LogicalResult LaunchOp::verifyRegions() {
810  // Kernel launch takes kNumConfigOperands leading operands for grid/block
811  // sizes and transforms them into kNumConfigRegionAttributes region arguments
812  // for block/thread identifiers and grid/block sizes.
813  if (!getBody().empty()) {
814  if (getBody().getNumArguments() <
815  kNumConfigRegionAttributes + getNumWorkgroupAttributions())
816  return emitOpError("unexpected number of region arguments");
817  }
818 
819  // Verify Attributions Address Spaces.
820  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
821  GPUDialect::getWorkgroupAddressSpace())) ||
822  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
823  GPUDialect::getPrivateAddressSpace())))
824  return failure();
825 
826  // Block terminators without successors are expected to exit the kernel region
827  // and must be `gpu.terminator`.
828  for (Block &block : getBody()) {
829  if (block.empty())
830  continue;
831  if (block.back().getNumSuccessors() != 0)
832  continue;
833  if (!isa<gpu::TerminatorOp>(&block.back())) {
834  return block.back()
835  .emitError()
836  .append("expected '", gpu::TerminatorOp::getOperationName(),
837  "' or a terminator with successors")
838  .attachNote(getLoc())
839  .append("in '", LaunchOp::getOperationName(), "' body region");
840  }
841  }
842 
843  if (getNumResults() == 0 && getAsyncToken())
844  return emitOpError("needs to be named when async keyword is specified");
845 
846  return success();
847 }
848 
849 // Pretty-print the kernel grid/block size assignment as
850 // (%iter-x, %iter-y, %iter-z) in
851 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
852 // where %size-* and %iter-* will correspond to the body region arguments.
853 static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
854  KernelDim3 operands, KernelDim3 ids) {
855  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
856  p << size.x << " = " << operands.x << ", ";
857  p << size.y << " = " << operands.y << ", ";
858  p << size.z << " = " << operands.z << ')';
859 }
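// For example, for the grid dimension this prints something like
//   (%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
// (hypothetical names), where the first triple are the id region arguments
// and the second binds each size region argument to its SSA operand.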
860 
861 void LaunchOp::print(OpAsmPrinter &p) {
862  if (getAsyncToken()) {
863  p << " async";
864  if (!getAsyncDependencies().empty())
865  p << " [" << getAsyncDependencies() << ']';
866  }
867  // Print the launch configuration.
868  if (hasClusterSize()) {
869  p << ' ' << getClustersKeyword();
870  printSizeAssignment(p, getClusterSize().value(),
871  getClusterSizeOperandValues().value(),
872  getClusterIds().value());
873  }
874  p << ' ' << getBlocksKeyword();
875  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
876  getBlockIds());
877  p << ' ' << getThreadsKeyword();
878  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
879  getThreadIds());
880  if (getDynamicSharedMemorySize())
881  p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
882  << getDynamicSharedMemorySize();
883 
884  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
885  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
886 
887  p << ' ';
888 
889  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
890  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
891  LaunchOp::getOperandSegmentSizeAttr(),
892  getNumWorkgroupAttributionsAttrName()});
893 }
894 
895 // Parse the size assignment blocks for blocks and threads. These have the form
896 // (%region_arg, %region_arg, %region_arg) in
897 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
898 // where %region_arg are percent-identifiers for the region arguments to be
899 // introduced further (SSA defs), and %operand are percent-identifiers for the
900 // SSA value uses.
901 static ParseResult
902 parseSizeAssignment(OpAsmParser &parser,
903  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
904  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
905  MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
906  assert(indices.size() == 3 && "space for three indices expected");
907  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
908  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
909  /*allowResultNumber=*/false) ||
910  parser.parseKeyword("in") || parser.parseLParen())
911  return failure();
912  std::move(args.begin(), args.end(), indices.begin());
913 
914  for (int i = 0; i < 3; ++i) {
915  if (i != 0 && parser.parseComma())
916  return failure();
917  if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
918  parser.parseEqual() || parser.parseOperand(sizes[i]))
919  return failure();
920  }
921 
922  return parser.parseRParen();
923 }
924 
925 /// Parses a Launch operation.
926 /// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
927 /// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
928 /// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
929 /// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
930 /// memory-attribution
931 /// region attr-dict?
932 /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
933 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
934  // Sizes of the grid and block.
935  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
936  sizes(LaunchOp::kNumConfigOperands);
937 
938  // Actual (data) operands passed to the kernel.
939  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;
940 
941  // Region arguments to be created.
942  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
943  LaunchOp::kNumConfigRegionAttributes);
944 
945  // Parse optional async dependencies.
946  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
947  Type asyncTokenType;
948  if (failed(
949  parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
950  parser.resolveOperands(asyncDependencies, asyncTokenType,
951  result.operands))
952  return failure();
953  if (parser.getNumResults() > 0)
954  result.types.push_back(asyncTokenType);
955 
956  bool hasCluster = false;
957  if (succeeded(
958  parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
959  hasCluster = true;
960  sizes.resize(9);
961  regionArgs.resize(18);
962  }
963  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
964  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
965 
966  // The last three operand segments assign the cluster size. In the region
967  // argument list, these are the last six arguments.
968  if (hasCluster) {
969  if (parseSizeAssignment(parser, sizesRef.drop_front(6),
970  regionArgsRef.slice(15, 3),
971  regionArgsRef.slice(12, 3)))
972  return failure();
973  }
974  // Parse the size assignment segments: the first segment assigns grid sizes
975  // and defines values for block identifiers; the second segment assigns block
976  // sizes and defines values for thread identifiers. In the region argument
977  // list, identifiers precede sizes, and block-related values precede
978  // thread-related values.
979  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
980  parseSizeAssignment(parser, sizesRef.take_front(3),
981  regionArgsRef.slice(6, 3),
982  regionArgsRef.slice(0, 3)) ||
983  parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
984  parseSizeAssignment(parser, sizesRef.drop_front(3),
985  regionArgsRef.slice(9, 3),
986  regionArgsRef.slice(3, 3)) ||
987  parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
988  result.operands))
989  return failure();
990 
991  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
992  bool hasDynamicSharedMemorySize = false;
993  if (!parser.parseOptionalKeyword(
994  LaunchOp::getDynamicSharedMemorySizeKeyword())) {
995  hasDynamicSharedMemorySize = true;
996  if (parser.parseOperand(dynamicSharedMemorySize) ||
997  parser.resolveOperand(dynamicSharedMemorySize,
998  parser.getBuilder().getI32Type(),
999  result.operands))
1000  return failure();
1001  }
1002 
1003  // Create the region arguments, it has kNumConfigRegionAttributes arguments
1004  // that correspond to block/thread identifiers and grid/block sizes, all
1005  // having `index` type, a variadic number of WorkGroup Attributions and
1006  // a variadic number of Private Attributions. The number of WorkGroup
1007  // Attributions is stored in the attr with name:
1008  // LaunchOp::getNumWorkgroupAttributionsAttrName().
1009  Type index = parser.getBuilder().getIndexType();
1010  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1011  LaunchOp::kNumConfigRegionAttributes + 6, index);
1012 
1013  SmallVector<OpAsmParser::Argument> regionArguments;
1014  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1015  OpAsmParser::Argument arg;
1016  arg.ssaName = std::get<0>(ssaValueAndType);
1017  arg.type = std::get<1>(ssaValueAndType);
1018  regionArguments.push_back(arg);
1019  }
1020 
1021  Builder &builder = parser.getBuilder();
1022  // Parse workgroup memory attributions.
1023  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1024  regionArguments)))
1025  return failure();
1026 
1027  // Store the number of operands we just parsed as the number of workgroup
1028  // memory attributions.
1029  unsigned numWorkgroupAttrs = regionArguments.size() -
1030  LaunchOp::kNumConfigRegionAttributes -
1031  (hasCluster ? 6 : 0);
1032  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1033  builder.getI64IntegerAttr(numWorkgroupAttrs));
1034 
1035  // Parse private memory attributions.
1036  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1037  regionArguments)))
1038  return failure();
1039 
1040  // Introduce the body region and parse it. The region has
1041  // kNumConfigRegionAttributes arguments that correspond to
1042  // block/thread identifiers and grid/block sizes, all having `index` type.
1043  Region *body = result.addRegion();
1044  if (parser.parseRegion(*body, regionArguments) ||
1045  parser.parseOptionalAttrDict(result.attributes))
1046  return failure();
1047 
1048  SmallVector<int32_t, 11> segmentSizes(11, 1);
1049  segmentSizes.front() = asyncDependencies.size();
1050 
1051  if (!hasCluster) {
1052  segmentSizes[7] = 0;
1053  segmentSizes[8] = 0;
1054  segmentSizes[9] = 0;
1055  }
1056  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1057  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1058  parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1059  return success();
1060 }
1061 
1062 /// Simplify the gpu.launch when the range of a thread or block ID is
1063 /// trivially known to be one.
1064 struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1065  using OpRewritePattern<LaunchOp>::OpRewritePattern;
1066  LogicalResult matchAndRewrite(LaunchOp op,
1067  PatternRewriter &rewriter) const override {
1068  // If the range implies a single value for `id`, replace `id`'s uses by
1069  // zero.
1070  Value zero;
1071  bool simplified = false;
1072  auto constPropIdUses = [&](Value id, Value size) {
1073  // Check if size is trivially one.
1074  if (!matchPattern(size, m_One()))
1075  return;
1076  if (id.getUses().empty())
1077  return;
1078  if (!simplified) {
1079  // Create a zero value the first time.
1080  OpBuilder::InsertionGuard guard(rewriter);
1081  rewriter.setInsertionPointToStart(&op.getBody().front());
1082  zero =
1083  rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
1084  }
1085  rewriter.replaceAllUsesWith(id, zero);
1086  simplified = true;
1087  };
1088  constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1089  constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1090  constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1091  constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1092  constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1093  constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1094 
1095  return success(simplified);
1096  }
1097 };
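// For example, if a launch is written as
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1) ...
// where %c1 is a constant 1 (hypothetical names), the pattern above replaces
// all uses of %bx, %by, and %bz in the body with a single constant 0.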
1098 
1099 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1100  MLIRContext *context) {
1101  rewrites.add<FoldLaunchArguments>(context);
1102 }
1103 
1104 /// Adds a new block argument that corresponds to buffers located in
1105 /// workgroup memory.
1106 BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1107  auto attrName = getNumWorkgroupAttributionsAttrName();
1108  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1109  (*this)->setAttr(attrName,
1110  IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1111  return getBody().insertArgument(
1112  LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1113 }
1114 
1115 /// Adds a new block argument that corresponds to buffers located in
1116 /// private memory.
1117 BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1118  // Buffers on the private memory always come after buffers on the workgroup
1119  // memory.
1120  return getBody().addArgument(type, loc);
1121 }
1122 
1123 //===----------------------------------------------------------------------===//
1124 // LaunchFuncOp
1125 //===----------------------------------------------------------------------===//
1126 
1127 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1128  SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1129  KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1130  ValueRange kernelOperands, Type asyncTokenType,
1131  ValueRange asyncDependencies,
1132  std::optional<KernelDim3> clusterSize) {
1133  assert(kernelSymbol.getNestedReferences().size() == 1 &&
1134  "expected a symbol reference with a single nested reference");
1135  result.addOperands(asyncDependencies);
1136  if (asyncTokenType)
1137  result.types.push_back(builder.getType<AsyncTokenType>());
1138 
1139  // Add grid and block sizes as op operands, followed by the data operands.
1140  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1141  getBlockSize.y, getBlockSize.z});
1142  if (clusterSize.has_value())
1143  result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1144  if (dynamicSharedMemorySize)
1145  result.addOperands(dynamicSharedMemorySize);
1146  result.addOperands(kernelOperands);
1147 
1148  Properties &prop = result.getOrAddProperties<Properties>();
1149  prop.kernel = kernelSymbol;
1150  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1151  // Initialize the segment sizes to 1.
1152  for (auto &sz : prop.operandSegmentSizes)
1153  sz = 1;
1154  prop.operandSegmentSizes[0] = asyncDependencies.size();
1155  if (!clusterSize.has_value()) {
1156  prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1157  prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1158  prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1159  }
1160  prop.operandSegmentSizes[segmentSizesLen - 3] =
1161  dynamicSharedMemorySize ? 1 : 0;
1162  prop.operandSegmentSizes[segmentSizesLen - 2] =
1163  static_cast<int32_t>(kernelOperands.size());
1164  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1165 }
1166 
1167 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1168  GPUFuncOp kernelFunc, KernelDim3 gridSize,
1169  KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1170  ValueRange kernelOperands, Type asyncTokenType,
1171  ValueRange asyncDependencies,
1172  std::optional<KernelDim3> clusterSize) {
1173  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1174  auto kernelSymbol =
1175  SymbolRefAttr::get(kernelModule.getNameAttr(),
1176  {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1177  build(builder, result, kernelSymbol, gridSize, getBlockSize,
1178  dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1179  asyncDependencies, clusterSize);
1180 }
1181 
1182 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1183  SymbolRefAttr kernel, KernelDim3 gridSize,
1184  KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1185  ValueRange kernelOperands, Value asyncObject,
1186  std::optional<KernelDim3> clusterSize) {
1187  // Add grid and block sizes as op operands, followed by the data operands.
1188  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1189  getBlockSize.y, getBlockSize.z});
1190  if (clusterSize.has_value())
1191  result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1192  if (dynamicSharedMemorySize)
1193  result.addOperands(dynamicSharedMemorySize);
1194  result.addOperands(kernelOperands);
1195  if (asyncObject)
1196  result.addOperands(asyncObject);
1197  Properties &prop = result.getOrAddProperties<Properties>();
1198  prop.kernel = kernel;
1199  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1200  // Initialize the segment sizes to 1.
1201  for (auto &sz : prop.operandSegmentSizes)
1202  sz = 1;
1203  prop.operandSegmentSizes[0] = 0;
1204  if (!clusterSize.has_value()) {
1205  prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1206  prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1207  prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1208  }
1209  prop.operandSegmentSizes[segmentSizesLen - 3] =
1210  dynamicSharedMemorySize ? 1 : 0;
1211  prop.operandSegmentSizes[segmentSizesLen - 2] =
1212  static_cast<int32_t>(kernelOperands.size());
1213  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1214 }
1215 
1216 StringAttr LaunchFuncOp::getKernelModuleName() {
1217  return getKernel().getRootReference();
1218 }
1219 
1220 StringAttr LaunchFuncOp::getKernelName() {
1221  return getKernel().getLeafReference();
1222 }
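// For example, for a kernel symbol of the form @kernels::@vecadd
// (hypothetical names), getKernelModuleName() returns "kernels" (the root
// reference) and getKernelName() returns "vecadd" (the leaf reference).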
1223 
1224 unsigned LaunchFuncOp::getNumKernelOperands() {
1225  return getKernelOperands().size();
1226 }
1227 
1228 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1229  return getKernelOperands()[i];
1230 }
1231 
1232 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1233  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1234  return KernelDim3{operands[0], operands[1], operands[2]};
1235 }
1236 
1237 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1238  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1239  return KernelDim3{operands[3], operands[4], operands[5]};
1240 }
1241 
1242 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1243  assert(hasClusterSize() &&
1244  "cluster size is not set, check hasClusterSize() first");
1245  auto operands = getOperands().drop_front(getAsyncDependencies().size());
1246  return KernelDim3{operands[6], operands[7], operands[8]};
1247 }
1248 
1249 LogicalResult LaunchFuncOp::verify() {
1250  auto module = (*this)->getParentOfType<ModuleOp>();
1251  if (!module)
1252  return emitOpError("expected to belong to a module");
1253 
1254  if (!module->getAttrOfType<UnitAttr>(
1255  GPUDialect::getContainerModuleAttrName()))
1256  return emitOpError("expected the closest surrounding module to have the '" +
1257  GPUDialect::getContainerModuleAttrName() +
1258  "' attribute");
1259 
1260  if (hasClusterSize()) {
1261  if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1262  getClusterSizeZ().getType() != getClusterSizeX().getType())
1263  return emitOpError()
1264  << "expects types of the cluster dimensions must be the same";
1265  }
1266 
1267  return success();
1268 }
1269 
1270 static ParseResult
1271 parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
1272  std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1273  Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1274  if (succeeded(parser.parseOptionalColon())) {
1275  if (parser.parseType(dimTy))
1276  return failure();
1277  } else {
1278  dimTy = IndexType::get(parser.getContext());
1279  }
1280  if (clusterValue.has_value()) {
1281  clusterXTy = clusterYTy = clusterZTy = dimTy;
1282  }
1283  return success();
1284 }
1285 
1286 static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1287  Value clusterValue, Type clusterXTy,
1288  Type clusterYTy, Type clusterZTy) {
1289  if (!dimTy.isIndex())
1290  printer << ": " << dimTy;
1291 }
1292 
1293 static ParseResult parseLaunchFuncOperands(
1294  OpAsmParser &parser,
1295  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
1296  SmallVectorImpl<Type> &argTypes) {
1297  if (parser.parseOptionalKeyword("args"))
1298  return success();
1299 
1300  auto parseElement = [&]() -> ParseResult {
1301  return failure(parser.parseOperand(argNames.emplace_back()) ||
1302  parser.parseColonType(argTypes.emplace_back()));
1303  };
1304 
1305  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1306  parseElement, " in argument list");
1307 }
1308 
1309 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
1310  OperandRange operands, TypeRange types) {
1311  if (operands.empty())
1312  return;
1313  printer << "args(";
1314  llvm::interleaveComma(llvm::zip(operands, types), printer,
1315  [&](const auto &pair) {
1316  printer.printOperand(std::get<0>(pair));
1317  printer << " : ";
1318  printer.printType(std::get<1>(pair));
1319  });
1320  printer << ")";
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // ShuffleOp
1325 //===----------------------------------------------------------------------===//
1326 
1327 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1328  int32_t offset, int32_t width, ShuffleMode mode) {
1329  build(builder, result, value,
1330  builder.create<arith::ConstantOp>(result.location,
1331  builder.getI32IntegerAttr(offset)),
1332  builder.create<arith::ConstantOp>(result.location,
1333  builder.getI32IntegerAttr(width)),
1334  mode);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // BarrierOp
1339 //===----------------------------------------------------------------------===//
1340 
1341 namespace {
1342 
1343 /// Erase a gpu.barrier that is immediately followed by another gpu.barrier; the threads are already synchronized.
1344 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1345  PatternRewriter &rewriter) {
1346  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1347  rewriter.eraseOp(op);
1348  return success();
1349  }
1350  return failure();
1351 }
1352 
1353 } // end anonymous namespace
1354 
1355 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1356  MLIRContext *context) {
1357  results.add(eraseRedundantGpuBarrierOps);
1358 }
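// For example, two back-to-back gpu.barrier ops collapse into one: the first
// barrier is erased because the pattern matches when its immediate successor
// in the block is another barrier.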
1359 
1360 //===----------------------------------------------------------------------===//
1361 // GPUFuncOp
1362 //===----------------------------------------------------------------------===//
1363 
1364 /// Adds a new block argument that corresponds to buffers located in
1365 /// workgroup memory.
1366 BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1367  auto attrName = getNumWorkgroupAttributionsAttrName();
1368  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1369  (*this)->setAttr(attrName,
1370  IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1371  return getBody().insertArgument(
1372  getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1373 }
1374 
1375 /// Adds a new block argument that corresponds to buffers located in
1376 /// private memory.
1377 BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1378  // Buffers on the private memory always come after buffers on the workgroup
1379  // memory.
1380  return getBody().addArgument(type, loc);
1381 }
1382 
1383 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1384  StringRef name, FunctionType type,
1385  TypeRange workgroupAttributions,
1386  TypeRange privateAttributions,
1387  ArrayRef<NamedAttribute> attrs) {
1388  OpBuilder::InsertionGuard g(builder);
1389 
1390  result.addAttribute(SymbolTable::getSymbolAttrName(),
1391  builder.getStringAttr(name));
1392  result.addAttribute(getFunctionTypeAttrName(result.name),
1393  TypeAttr::get(type));
1394  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1395  builder.getI64IntegerAttr(workgroupAttributions.size()));
1396  result.addAttributes(attrs);
1397  Region *body = result.addRegion();
1398  Block *entryBlock = builder.createBlock(body);
1399 
1400  // TODO: Allow passing in proper locations here.
1401  for (Type argTy : type.getInputs())
1402  entryBlock->addArgument(argTy, result.location);
1403  for (Type argTy : workgroupAttributions)
1404  entryBlock->addArgument(argTy, result.location);
1405  for (Type argTy : privateAttributions)
1406  entryBlock->addArgument(argTy, result.location);
1407 }
1408 
1409 /// Parses a GPU function memory attribution.
1410 ///
1411 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1412 /// (`private` `(` ssa-id-and-type-list `)`)?
1413 ///
1414 /// Note that this function parses only one of the two similar parts, with the
1415 /// keyword provided as argument.
1416 static ParseResult
1417 parseAttributions(OpAsmParser &parser, StringRef keyword,
1418  SmallVectorImpl<OpAsmParser::Argument> &args,
1419  Attribute &attributionAttrs) {
1420  // If we could not parse the keyword, just assume empty list and succeed.
1421  if (failed(parser.parseOptionalKeyword(keyword)))
1422  return success();
1423 
1424  size_t existingArgs = args.size();
1425  ParseResult result =
1426  parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
1427  /*allowType=*/true, /*allowAttrs=*/true);
1428  if (failed(result))
1429  return result;
1430 
1431  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1432  [](const OpAsmParser::Argument &arg) -> bool {
1433  return arg.attrs && !arg.attrs.empty();
1434  });
1435  if (!hadAttrs) {
1436  attributionAttrs = nullptr;
1437  return result;
1438  }
1439 
1440  Builder &builder = parser.getBuilder();
1441  SmallVector<Attribute> attributionAttrsVec;
1442  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1443  if (!argument.attrs)
1444  attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1445  else
1446  attributionAttrsVec.push_back(argument.attrs);
1447  }
1448  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1449  return result;
1450 }
1451 
1452 /// Parses a GPU function.
1453 ///
1454 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1455 /// (`->` function-result-list)? memory-attribution `kernel`?
1456 /// function-attributes? region
1457 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1458  SmallVector<OpAsmParser::Argument> entryArgs;
1459  SmallVector<DictionaryAttr> resultAttrs;
1460  SmallVector<Type> resultTypes;
1461  bool isVariadic;
1462 
1463  // Parse the function name.
1464  StringAttr nameAttr;
1465  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1466  result.attributes))
1467  return failure();
1468 
1469  auto signatureLocation = parser.getCurrentLocation();
1470  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1471  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1472  resultAttrs)))
1473  return failure();
1474 
1475  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1476  return parser.emitError(signatureLocation)
1477  << "gpu.func requires named arguments";
1478 
1479  // Construct the function type. More types will be added to the region, but
1480  // not to the function type.
1481  Builder &builder = parser.getBuilder();
1482 
1483  SmallVector<Type> argTypes;
1484  for (auto &arg : entryArgs)
1485  argTypes.push_back(arg.type);
1486  auto type = builder.getFunctionType(argTypes, resultTypes);
1487  result.addAttribute(getFunctionTypeAttrName(result.name),
1488  TypeAttr::get(type));
1489 
1490  call_interface_impl::addArgAndResultAttrs(
1491  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1492  getResAttrsAttrName(result.name));
1493 
1494  Attribute workgroupAttributionAttrs;
1495  // Parse workgroup memory attributions.
1496  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1497  entryArgs, workgroupAttributionAttrs)))
1498  return failure();
1499 
1500  // Store the number of operands we just parsed as the number of workgroup
1501  // memory attributions.
1502  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1503  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1504  builder.getI64IntegerAttr(numWorkgroupAttrs));
1505  if (workgroupAttributionAttrs)
1506  result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1507  workgroupAttributionAttrs);
1508 
1509  Attribute privateAttributionAttrs;
1510  // Parse private memory attributions.
1511  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1512  entryArgs, privateAttributionAttrs)))
1513  return failure();
1514  if (privateAttributionAttrs)
1515  result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1516  privateAttributionAttrs);
1517 
1518  // Parse the kernel attribute if present.
1519  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1520  result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1521  builder.getUnitAttr());
1522 
1523  // Parse attributes.
1524  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1525  return failure();
1526 
1527  // Parse the region. If no argument names were provided, take all names
1528  // (including those of attributions) from the entry block.
1529  auto *body = result.addRegion();
1530  return parser.parseRegion(*body, entryArgs);
1531 }
1532 
1533 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1534  ArrayRef<BlockArgument> values,
1535  ArrayAttr attributes) {
1536  if (values.empty())
1537  return;
1538 
1539  p << ' ' << keyword << '(';
1540  llvm::interleaveComma(
1541  llvm::enumerate(values), p, [&p, attributes](auto pair) {
1542  BlockArgument v = pair.value();
1543  p << v << " : " << v.getType();
1544 
1545  size_t attributionIndex = pair.index();
1546  DictionaryAttr attrs;
1547  if (attributes && attributionIndex < attributes.size())
1548  attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1549  if (attrs)
1550  p.printOptionalAttrDict(attrs.getValue());
1551  });
1552  p << ')';
1553 }
1554 
1555 void GPUFuncOp::print(OpAsmPrinter &p) {
1556  p << ' ';
1557  p.printSymbolName(getName());
1558 
1559  FunctionType type = getFunctionType();
1560  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1561  /*isVariadic=*/false,
1562  type.getResults());
1563 
1564  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1565  getWorkgroupAttribAttrs().value_or(nullptr));
1566  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1567  getPrivateAttribAttrs().value_or(nullptr));
1568  if (isKernel())
1569  p << ' ' << getKernelKeyword();
1570 
1571  function_interface_impl::printFunctionAttributes(
1572  p, *this,
1573  {getNumWorkgroupAttributionsAttrName(),
1574  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1575  getArgAttrsAttrName(), getResAttrsAttrName(),
1576  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1577  p << ' ';
1578  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1579 }
1580 
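// For illustration only (not part of this file): a sketch of the textual form
// that the GPUFuncOp parse/print methods above round-trip. The symbol name,
// types, memory spaces, and attribute are made up.
//
//   gpu.func @kernel(%arg0: index)
//       workgroup(%workgroup: memref<32xf32, 3>)
//       private(%private: memref<1xf32, 5>)
//       kernel
//       attributes {foo = "bar"} {
//     gpu.return
//   }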
1581 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1582  StringAttr attrName) {
1583  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1584  if (!allAttrs || index >= allAttrs.size())
1585  return DictionaryAttr();
1586  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1587 }
1588 
1589 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1590  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1591 }
1592 
1593 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1594  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1595 }
1596 
1597 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1598  DictionaryAttr value, StringAttr attrName) {
1599  MLIRContext *ctx = op.getContext();
1600  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1601  SmallVector<Attribute> elements;
1602  if (allAttrs)
1603  elements.append(allAttrs.begin(), allAttrs.end());
1604  while (elements.size() <= index)
1605  elements.push_back(DictionaryAttr::get(ctx));
1606  if (!value)
1607  elements[index] = DictionaryAttr::get(ctx);
1608  else
1609  elements[index] = value;
1610  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1611  op->setAttr(attrName, newValue);
1612 }
1613 
1614 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1615  DictionaryAttr value) {
1616  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1617 }
1618 
1619 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1620  DictionaryAttr value) {
1621  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1622 }
1623 
1624 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1625  StringAttr name, StringAttr attrsName) {
1626  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1627  if (!dict)
1628  return Attribute();
1629  return dict.get(name);
1630 }
1631 
1632 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1633  StringAttr name) {
1634  assert(index < getNumWorkgroupAttributions() &&
1635  "index must map to a workgroup attribution");
1636  return getAttributionAttr(*this, index, name,
1637  getWorkgroupAttribAttrsAttrName());
1638 }
1639 
1640 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1641  StringAttr name) {
1642  assert(index < getNumPrivateAttributions() &&
1643  "index must map to a private attribution");
1644  return getAttributionAttr(*this, index, name,
1645  getPrivateAttribAttrsAttrName());
1646 }
1647 
1648 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1649  Attribute value, StringAttr attrsName) {
1650  MLIRContext *ctx = op.getContext();
1651  SmallVector<NamedAttribute> elems;
1652  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1653  if (oldDict)
1654  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1655 
1656  bool found = false;
1657  bool mustSort = true;
1658  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1659  if (elems[i].getName() == name) {
1660  found = true;
1661  if (!value) {
1662  std::swap(elems[i], elems[elems.size() - 1]);
1663  elems.pop_back();
1664  } else {
1665  mustSort = false;
1666  elems[i] = NamedAttribute(elems[i].getName(), value);
1667  }
1668  break;
1669  }
1670  }
1671  if (!found) {
1672  if (!value)
1673  return;
1674  elems.emplace_back(name, value);
1675  }
1676  if (mustSort) {
1677  DictionaryAttr::sortInPlace(elems);
1678  }
1679  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1680  setAttributionAttrs(op, index, newDict, attrsName);
1681 }
1682 
1683 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1684  Attribute value) {
1685  assert(index < getNumWorkgroupAttributions() &&
1686  "index must map to a workgroup attribution");
1687  setAttributionAttr(*this, index, name, value,
1688  getWorkgroupAttribAttrsAttrName());
1689 }
1690 
1691 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1692  Attribute value) {
1693  assert(index < getNumPrivateAttributions() &&
1694  "index must map to a private attribution");
1695  setAttributionAttr(*this, index, name, value,
1696  getPrivateAttribAttrsAttrName());
1697 }
1698 
1699 LogicalResult GPUFuncOp::verifyType() {
1700  if (isKernel() && getFunctionType().getNumResults() != 0)
1701  return emitOpError() << "expected void return type for kernel function";
1702 
1703  return success();
1704 }
1705 
1706 /// Verifies the body of the function.
1707 LogicalResult GPUFuncOp::verifyBody() {
1708  if (empty())
1709  return emitOpError() << "expected body with at least one block";
1710  unsigned numFuncArguments = getNumArguments();
1711  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1712  unsigned numBlockArguments = front().getNumArguments();
1713  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1714  return emitOpError() << "expected at least "
1715  << numFuncArguments + numWorkgroupAttributions
1716  << " arguments to body region";
1717 
1718  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1719  for (unsigned i = 0; i < numFuncArguments; ++i) {
1720  Type blockArgType = front().getArgument(i).getType();
1721  if (funcArgTypes[i] != blockArgType)
1722  return emitOpError() << "expected body region argument #" << i
1723  << " to be of type " << funcArgTypes[i] << ", got "
1724  << blockArgType;
1725  }
1726 
1727  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1728  GPUDialect::getWorkgroupAddressSpace())) ||
1729  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1730  GPUDialect::getPrivateAddressSpace())))
1731  return failure();
1732 
1733  return success();
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // ReturnOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 LogicalResult gpu::ReturnOp::verify() {
1741  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1742 
1743  FunctionType funType = function.getFunctionType();
1744 
1745  if (funType.getNumResults() != getOperands().size())
1746  return emitOpError()
1747  .append("expected ", funType.getNumResults(), " result operands")
1748  .attachNote(function.getLoc())
1749  .append("return type declared here");
1750 
1751  for (const auto &pair : llvm::enumerate(
1752  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1753  auto [type, operand] = pair.value();
1754  if (type != operand.getType())
1755  return emitOpError() << "unexpected type `" << operand.getType()
1756  << "' for operand #" << pair.index();
1757  }
1758  return success();
1759 }
1760 
1761 //===----------------------------------------------------------------------===//
1762 // GPUModuleOp
1763 //===----------------------------------------------------------------------===//
1764 
1765 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1766  StringRef name, ArrayAttr targets,
1767  Attribute offloadingHandler) {
1768  result.addRegion()->emplaceBlock();
1769  Properties &props = result.getOrAddProperties<Properties>();
1770  if (targets)
1771  props.targets = targets;
1772  props.setSymName(builder.getStringAttr(name));
1773  props.offloadingHandler = offloadingHandler;
1774 }
1775 
1776 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1777  StringRef name, ArrayRef<Attribute> targets,
1778  Attribute offloadingHandler) {
1779  build(builder, result, name,
1780  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1781  offloadingHandler);
1782 }
1783 
1784 bool GPUModuleOp::hasTarget(Attribute target) {
1785  if (ArrayAttr targets = getTargetsAttr())
1786  return llvm::count(targets.getValue(), target);
1787  return false;
1788 }
1789 
1790 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1791  ArrayAttr &targetsAttr = getProperties().targets;
1792  SmallVector<Attribute> targetsVector(targets);
1793  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1794 }
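
// For illustration only (not part of this file): a gpu.module carrying the
// `targets` array manipulated by the builders above; the target attribute
// shown is a made-up example.
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//     ...
//   }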
1795 
1796 //===----------------------------------------------------------------------===//
1797 // GPUBinaryOp
1798 //===----------------------------------------------------------------------===//
1799 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1800  Attribute offloadingHandler, ArrayAttr objects) {
1801  auto &properties = result.getOrAddProperties<Properties>();
1802  result.attributes.push_back(builder.getNamedAttr(
1803  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1804  properties.objects = objects;
1805  if (offloadingHandler)
1806  properties.offloadingHandler = offloadingHandler;
1807  else
1808  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1809 }
1810 
1811 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1812  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1813  build(builder, result, name, offloadingHandler,
1814  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1815 }
1816 
1817 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1818  Attribute &offloadingHandler) {
1819  if (succeeded(parser.parseOptionalLess())) {
1820  if (parser.parseAttribute(offloadingHandler))
1821  return failure();
1822  if (parser.parseGreater())
1823  return failure();
1824  }
1825  if (!offloadingHandler)
1826  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1827  return success();
1828 }
1829 
1830 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1831  Attribute offloadingHandler) {
1832  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1833  printer << '<' << offloadingHandler << '>';
1834 }
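
// For illustration only (not part of this file): gpu.binary forms handled by
// parseOffloadingHandler/printOffloadingHandler above. The default offloading
// handler (a #gpu.select_object with a null target) is elided when printing;
// targets and payloads below are made up.
//
//   gpu.binary @kernels [#gpu.object<#nvvm.target, "...">]
//   gpu.binary @kernels <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#rocdl.target, "...">]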
1835 
1836 //===----------------------------------------------------------------------===//
1837 // GPUMemcpyOp
1838 //===----------------------------------------------------------------------===//
1839 
1840 LogicalResult MemcpyOp::verify() {
1841  auto srcType = getSrc().getType();
1842  auto dstType = getDst().getType();
1843 
1844  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1845  return emitOpError("arguments have incompatible element type");
1846 
1847  if (failed(verifyCompatibleShape(srcType, dstType)))
1848  return emitOpError("arguments have incompatible shape");
1849 
1850  return success();
1851 }
1852 
1853 namespace {
1854 
1855 /// Erases a common case of copy ops where a destination value is used only by
1856 /// the copy op, alloc and dealloc ops.
1857 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1858  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1859 
1860  LogicalResult matchAndRewrite(MemcpyOp op,
1861  PatternRewriter &rewriter) const override {
1862  Value dest = op.getDst();
1863  Operation *destDefOp = dest.getDefiningOp();
1864  // `dest` must be defined by an op having Allocate memory effect in order to
1865  // perform the folding.
1866  if (!destDefOp ||
1867  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1868  return failure();
1869  // We can erase `op` iff `dest` has no other use apart from its
1870  // use by `op` and dealloc ops.
1871  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1872  return user != op &&
1873  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1874  }))
1875  return failure();
1876  // We can perform the folding if and only if op has a single async
1877  // dependency and produces an async token as result, or if it does not have
1878  // any async dependency and does not produce any async token result.
1879  if (op.getAsyncDependencies().size() > 1 ||
1880  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1881  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1882  return failure();
1883  rewriter.replaceOp(op, op.getAsyncDependencies());
1884  return success();
1885  }
1886 };
1887 
1888 } // end anonymous namespace
1889 
1890 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1891  MLIRContext *context) {
1892  results.add<EraseTrivialCopyOp>(context);
1893 }
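
// For illustration only (not part of this file): the kind of IR that
// EraseTrivialCopyOp removes. Here %dst is defined by an allocating op and is
// used only by the copy and the dealloc, so the gpu.memcpy can be erased and
// the alloc/dealloc become dead in turn. Names and types are made up.
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>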
1894 
1895 //===----------------------------------------------------------------------===//
1896 // GPU_SubgroupMmaLoadMatrixOp
1897 //===----------------------------------------------------------------------===//
1898 
1899 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1900  auto srcType = getSrcMemref().getType();
1901  auto resType = getRes().getType();
1902  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1903  auto operand = resMatrixType.getOperand();
1904  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1905 
1906  if (!srcMemrefType.isLastDimUnitStride())
1907  return emitError(
1908  "expected source memref most minor dim must have unit stride");
1909 
1910  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1911  return emitError("only AOp, BOp and COp can be loaded");
1912 
1913  return success();
1914 }
1915 
1916 //===----------------------------------------------------------------------===//
1917 // GPU_SubgroupMmaStoreMatrixOp
1918 //===----------------------------------------------------------------------===//
1919 
1920 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1921  auto srcType = getSrc().getType();
1922  auto dstType = getDstMemref().getType();
1923  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1924  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1925 
1926  if (!dstMemrefType.isLastDimUnitStride())
1927  return emitError(
1928  "expected destination memref most minor dim must have unit stride");
1929 
1930  if (srcMatrixType.getOperand() != "COp")
1931  return emitError(
1932  "expected the operand matrix being stored to have 'COp' operand type");
1933 
1934  return success();
1935 }
1936 
1937 //===----------------------------------------------------------------------===//
1938 // GPU_SubgroupMmaComputeOp
1939 //===----------------------------------------------------------------------===//
1940 
1941 LogicalResult SubgroupMmaComputeOp::verify() {
1942  enum OperandMap { A, B, C };
1943  SmallVector<MMAMatrixType, 3> opTypes;
1944  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1945  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1946  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1947 
1948  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1949  opTypes[C].getOperand() != "COp")
1950  return emitError("operands must be in the order AOp, BOp, COp");
1951 
1952  ArrayRef<int64_t> aShape, bShape, cShape;
1953  aShape = opTypes[A].getShape();
1954  bShape = opTypes[B].getShape();
1955  cShape = opTypes[C].getShape();
1956 
1957  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1958  bShape[1] != cShape[1])
1959  return emitError("operand shapes do not satisfy matmul constraints");
1960 
1961  return success();
1962 }
1963 
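// For illustration only (not part of this file): a shape assignment that
// satisfies the matmul constraints checked above, with made-up dimensions:
//   A : !gpu.mma_matrix<16x8xf16,  "AOp">   (M x K)
//   B : !gpu.mma_matrix<8x32xf16,  "BOp">   (K x N)
//   C : !gpu.mma_matrix<16x32xf16, "COp">   (M x N)
// i.e. aShape[1] == bShape[0], aShape[0] == cShape[0], bShape[1] == cShape[1].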
1964 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1965  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1966  return memref::foldMemRefCast(*this);
1967 }
1968 
1969 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1970  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1971  return memref::foldMemRefCast(*this);
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // GPU_WaitOp
1976 //===----------------------------------------------------------------------===//
1977 
1978 namespace {
1979 
1980 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
1981 /// %t = gpu.wait async [] // No async dependencies.
1982 /// ... gpu.wait ... [%t, ...] // %t can be removed.
1983 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1984 public:
1985  using OpRewritePattern::OpRewritePattern;
1986 
1987  LogicalResult matchAndRewrite(WaitOp op,
1988  PatternRewriter &rewriter) const final {
1989  auto predicate = [](Value value) {
1990  auto waitOp = value.getDefiningOp<WaitOp>();
1991  return waitOp && waitOp->getNumOperands() == 0;
1992  };
1993  if (llvm::none_of(op.getAsyncDependencies(), predicate))
1994  return failure();
1995  SmallVector<Value> validOperands;
1996  for (Value operand : op->getOperands()) {
1997  if (predicate(operand))
1998  continue;
1999  validOperands.push_back(operand);
2000  }
2001  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2002  return success();
2003  }
2004 };
2005 
2006 /// Simplify trivial gpu.wait ops for the following patterns.
2007 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2008 /// dependencies).
2009 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2010 /// %t0.
2011 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2012 /// dependencies nor return any token.
2013 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2014 public:
2015  using OpRewritePattern::OpRewritePattern;
2016 
2017  LogicalResult matchAndRewrite(WaitOp op,
2018  PatternRewriter &rewriter) const final {
2019  // Erase gpu.wait ops that neither have any async dependencies nor return
2020  // any async token.
2021  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2022  rewriter.eraseOp(op);
2023  return success();
2024  }
2025  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2026  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2027  op.getAsyncToken()) {
2028  rewriter.replaceOp(op, op.getAsyncDependencies());
2029  return success();
2030  }
2031  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2032  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2033  rewriter.eraseOp(op);
2034  return success();
2035  }
2036  return failure();
2037  }
2038 };
2039 
2040 } // end anonymous namespace
2041 
2042 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2043  MLIRContext *context) {
2044  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2045 }
2046 
2047 //===----------------------------------------------------------------------===//
2048 // GPU_AllocOp
2049 //===----------------------------------------------------------------------===//
2050 
2051 LogicalResult AllocOp::verify() {
2052  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2053 
2054  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2055  return emitOpError("dimension operand count does not equal memref "
2056  "dynamic dimension count");
2057 
2058  unsigned numSymbols = 0;
2059  if (!memRefType.getLayout().isIdentity())
2060  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2061  if (getSymbolOperands().size() != numSymbols) {
2062  return emitOpError(
2063  "symbol operand count does not equal memref symbol count");
2064  }
2065 
2066  return success();
2067 }
2068 
2069 namespace {
2070 
2071 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2072 /// `memref::AllocOp`.
2073 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2074  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2075 
2076  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2077  PatternRewriter &rewriter) const override {
2078  std::optional<int64_t> index = dimOp.getConstantIndex();
2079  if (!index)
2080  return failure();
2081 
2082  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2083  if (!memrefType || index.value() >= memrefType.getRank() ||
2084  !memrefType.isDynamicDim(index.value()))
2085  return failure();
2086 
2087  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2088  if (!alloc)
2089  return failure();
2090 
2091  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2092  memrefType.getDynamicDimIndex(index.value()));
2093  rewriter.replaceOp(dimOp, substituteOp);
2094  return success();
2095  }
2096 };
2097 
2098 } // namespace
2099 
2100 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2101  MLIRContext *context) {
2102  results.add<SimplifyDimOfAllocOp>(context);
2103 }
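
// For illustration only (not part of this file): the fold performed by
// SimplifyDimOfAllocOp; names are made up. The memref.dim below resolves to
// the size operand that was passed to gpu.alloc.
//
//   %m = gpu.alloc (%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>   // folds to %size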
2104 
2105 //===----------------------------------------------------------------------===//
2106 // GPU object attribute
2107 //===----------------------------------------------------------------------===//
2108 
2109 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2110  Attribute target, CompilationTarget format,
2111  StringAttr object, DictionaryAttr properties,
2112  KernelTableAttr kernels) {
2113  if (!target)
2114  return emitError() << "the target attribute cannot be null";
2115  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2116  return success();
2117  return emitError() << "the target attribute must implement or promise the "
2118  "`gpu::TargetAttrInterface`";
2119 }
2120 
2121 namespace {
2122 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2123  StringAttr &object) {
2124  std::optional<CompilationTarget> formatResult;
2125  StringRef enumKeyword;
2126  auto loc = odsParser.getCurrentLocation();
2127  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2128  formatResult = CompilationTarget::Fatbin;
2129  if (!formatResult &&
2130  (formatResult =
2131  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2132  odsParser.parseEqual())
2133  return odsParser.emitError(loc, "expected an equal sign");
2134  if (!formatResult)
2135  return odsParser.emitError(loc, "expected keyword for GPU object format");
2136  FailureOr<StringAttr> objectResult =
2137  FieldParser<StringAttr>::parse(odsParser);
2138  if (failed(objectResult))
2139  return odsParser.emitError(odsParser.getCurrentLocation(),
2140  "failed to parse GPU_ObjectAttr parameter "
2141  "'object' which is to be a `StringAttr`");
2142  format = *formatResult;
2143  object = *objectResult;
2144  return success();
2145 }
2146 
2147 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2148  StringAttr object) {
2149  if (format != CompilationTarget::Fatbin)
2150  odsParser << stringifyEnum(format) << " = ";
2151  odsParser << object;
2152 }
2153 } // namespace
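
// For illustration only (not part of this file): #gpu.object attributes as
// handled by parseObject/printObject above. The format keyword is omitted for
// the default fatbin format; targets and payloads are made up, and the
// non-default keyword shown follows the CompilationTarget enum spelling.
//
//   #gpu.object<#nvvm.target, "fatbin-payload">
//   #gpu.object<#nvvm.target, assembly = "ptx-text">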
2154 
2155 //===----------------------------------------------------------------------===//
2156 // GPU select object attribute
2157 //===----------------------------------------------------------------------===//
2158 
2159 LogicalResult
2160 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2161  Attribute target) {
2162  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2163  if (target) {
2164  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2165  if (intAttr.getInt() < 0) {
2166  return emitError() << "the object index must be positive";
2167  }
2168  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2169  return emitError()
2170  << "the target attribute must be a GPU Target attribute";
2171  }
2172  }
2173  return success();
2174 }
2175 
2176 //===----------------------------------------------------------------------===//
2177 // DynamicSharedMemoryOp
2178 //===----------------------------------------------------------------------===//
2179 
2180 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2181  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2182  return emitOpError() << "must be inside an op with symbol table";
2183 
2184  MemRefType memrefType = getResultMemref().getType();
2185  // Check address space
2186  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2187  return emitOpError() << "address space must be "
2188  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2189  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2190  }
2191  if (memrefType.hasStaticShape()) {
2192  return emitOpError() << "result memref type must be memref<?xi8, "
2193  "#gpu.address_space<workgroup>>";
2194  }
2195  return success();
2196 }
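
// For illustration only (not part of this file): the form accepted by the
// verifier above, i.e. a dynamically shaped i8 memref in workgroup memory.
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>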
2197 
2198 //===----------------------------------------------------------------------===//
2199 // GPU WarpExecuteOnLane0Op
2200 //===----------------------------------------------------------------------===//
2201 
2202 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2203  p << "(" << getLaneid() << ")";
2204 
2205  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2206  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2207  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2208 
2209  if (!getArgs().empty())
2210  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2211  if (!getResults().empty())
2212  p << " -> (" << getResults().getTypes() << ')';
2213  p << " ";
2214  p.printRegion(getRegion(),
2215  /*printEntryBlockArgs=*/true,
2216  /*printBlockTerminators=*/!getResults().empty());
2217  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2218 }
2219 
2220 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2221  OperationState &result) {
2222  // Create the region.
2223  result.regions.reserve(1);
2224  Region *warpRegion = result.addRegion();
2225 
2226  auto &builder = parser.getBuilder();
2227  OpAsmParser::UnresolvedOperand laneId;
2228 
2229  // Parse predicate operand.
2230  if (parser.parseLParen() ||
2231  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2232  parser.parseRParen())
2233  return failure();
2234 
2235  int64_t warpSize;
2236  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2237  parser.parseRSquare())
2238  return failure();
2239  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2240  builder.getContext())),
2241  builder.getI64IntegerAttr(warpSize));
2242 
2243  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2244  return failure();
2245 
2246  llvm::SMLoc inputsOperandsLoc;
2247  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2248  SmallVector<Type> inputTypes;
2249  if (succeeded(parser.parseOptionalKeyword("args"))) {
2250  if (parser.parseLParen())
2251  return failure();
2252 
2253  inputsOperandsLoc = parser.getCurrentLocation();
2254  if (parser.parseOperandList(inputsOperands) ||
2255  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2256  return failure();
2257  }
2258  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2259  result.operands))
2260  return failure();
2261 
2262  // Parse optional results type list.
2263  if (parser.parseOptionalArrowTypeList(result.types))
2264  return failure();
2265  // Parse the region.
2266  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2267  /*argTypes=*/{}))
2268  return failure();
2269  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2270 
2271  // Parse the optional attribute list.
2272  if (parser.parseOptionalAttrDict(result.attributes))
2273  return failure();
2274  return success();
2275 }
2276 
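// For illustration only (not part of this file): the textual form handled by
// the print/parse methods above, with made-up values. With warp size 32, the
// vector<4xf32> argument expands to a vector<128xf32> block argument, and the
// vector<32xf32> yield distributes to a vector<1xf32> result.
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg0: vector<128xf32>):
//     ...
//     gpu.yield %x : vector<32xf32>
//   }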
2277 void WarpExecuteOnLane0Op::getSuccessorRegions(
2278  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2279  if (!point.isParent()) {
2280  regions.push_back(RegionSuccessor(getResults()));
2281  return;
2282  }
2283 
2284  // The warp region is always executed
2285  regions.push_back(RegionSuccessor(&getWarpRegion()));
2286 }
2287 
2288 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2289  TypeRange resultTypes, Value laneId,
2290  int64_t warpSize) {
2291  build(builder, result, resultTypes, laneId, warpSize,
2292  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2293 }
2294 
2295 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2296  TypeRange resultTypes, Value laneId,
2297  int64_t warpSize, ValueRange args,
2298  TypeRange blockArgTypes) {
2299  result.addOperands(laneId);
2300  result.addAttribute(getAttributeNames()[0],
2301  builder.getI64IntegerAttr(warpSize));
2302  result.addTypes(resultTypes);
2303  result.addOperands(args);
2304  assert(args.size() == blockArgTypes.size());
2305  OpBuilder::InsertionGuard guard(builder);
2306  Region *warpRegion = result.addRegion();
2307  Block *block = builder.createBlock(warpRegion);
2308  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2309  block->addArgument(type, arg.getLoc());
2310 }
2311 
2312 /// Helper to check that the distributed vector type is consistent with the
2313 /// expanded type and the distributed size.
2314 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2315  int64_t warpSize, Operation *op) {
2316  // If the types match, there is no distribution.
2317  if (expanded == distributed)
2318  return success();
2319  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2320  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2321  if (!expandedVecType || !distributedVecType)
2322  return op->emitOpError("expected vector type for distributed operands.");
2323  if (expandedVecType.getRank() != distributedVecType.getRank() ||
2324  expandedVecType.getElementType() != distributedVecType.getElementType())
2325  return op->emitOpError(
2326  "expected distributed vectors to have same rank and element type.");
2327 
2328  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2329  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2330  int64_t eDim = expandedVecType.getDimSize(i);
2331  int64_t dDim = distributedVecType.getDimSize(i);
2332  if (eDim == dDim)
2333  continue;
2334  if (eDim % dDim != 0)
2335  return op->emitOpError()
2336  << "expected expanded vector dimension #" << i << " (" << eDim
2337  << ") to be a multiplier of the distributed vector dimension ("
2338  << dDim << ")";
2339  scales[i] = eDim / dDim;
2340  }
2341  if (std::accumulate(scales.begin(), scales.end(), 1,
2342  std::multiplies<int64_t>()) != warpSize)
2343  return op->emitOpError()
2344  << "incompatible distribution dimensions from " << expandedVecType
2345  << " to " << distributedVecType << " with warp size = " << warpSize;
2346 
2347  return success();
2348 }
2349 
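// For illustration only (not part of this file): with warp size 32, an
// expanded vector<64x4xf32> may distribute to vector<2x4xf32>, since the
// per-dimension ratios are 64/2 = 32 and 4/4 = 1 and their product equals the
// warp size.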
2350 LogicalResult WarpExecuteOnLane0Op::verify() {
2351  if (getArgs().size() != getWarpRegion().getNumArguments())
2352  return emitOpError(
2353  "expected same number of op arguments and block arguments.");
2354  auto yield =
2355  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2356  if (yield.getNumOperands() != getNumResults())
2357  return emitOpError(
2358  "expected same number of yield operands and return values.");
2359  int64_t warpSize = getWarpSize();
2360  for (auto [regionArg, arg] :
2361  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2362  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2363  warpSize, getOperation())))
2364  return failure();
2365  }
2366  for (auto [yieldOperand, result] :
2367  llvm::zip_equal(yield.getOperands(), getResults())) {
2368  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2369  warpSize, getOperation())))
2370  return failure();
2371  }
2372  return success();
2373 }
2374 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2375  return succeeded(
2376  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2377 }
2378 
2379 //===----------------------------------------------------------------------===//
2380 // GPU KernelMetadataAttr
2381 //===----------------------------------------------------------------------===//
2382 
2383 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2384  DictionaryAttr metadata) {
2385  assert(kernel && "invalid kernel");
2386  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2387  kernel.getAllArgAttrs(), metadata);
2388 }
2389 
2390 KernelMetadataAttr
2391 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2392  FunctionOpInterface kernel,
2393  DictionaryAttr metadata) {
2394  assert(kernel && "invalid kernel");
2395  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2396  kernel.getAllArgAttrs(), metadata);
2397 }
2398 
2399 KernelMetadataAttr
2400 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2401  if (attrs.empty())
2402  return *this;
2403  NamedAttrList attrList;
2404  if (DictionaryAttr dict = getMetadata())
2405  attrList.append(dict);
2406  attrList.append(attrs);
2407  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2408  attrList.getDictionary(getContext()));
2409 }
2410 
2411 LogicalResult
2412 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2413  StringAttr name, Type functionType,
2414  ArrayAttr argAttrs, DictionaryAttr metadata) {
2415  if (name.empty())
2416  return emitError() << "the kernel name can't be empty";
2417  if (argAttrs) {
2418  if (llvm::any_of(argAttrs, [](Attribute attr) {
2419  return !llvm::isa<DictionaryAttr>(attr);
2420  }))
2421  return emitError()
2422  << "all attributes in the array must be a dictionary attribute";
2423  }
2424  return success();
2425 }
2426 
2427 //===----------------------------------------------------------------------===//
2428 // GPU KernelTableAttr
2429 //===----------------------------------------------------------------------===//
2430 
2431 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2432  ArrayRef<KernelMetadataAttr> kernels,
2433  bool isSorted) {
2434  // Note that `is_sorted` is always only invoked once even with assertions ON.
2435  assert((!isSorted || llvm::is_sorted(kernels)) &&
2436  "expected a sorted kernel array");
2437  // Immediately return the attribute if the array is sorted.
2438  if (isSorted || llvm::is_sorted(kernels))
2439  return Base::get(context, kernels);
2440  // Sort the array.
2441  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2442  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2443  return Base::get(context, kernelsTmp);
2444 }
2445 
2446 KernelTableAttr KernelTableAttr::getChecked(
2447  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2448  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2449  // Note that `is_sorted` is always only invoked once even with assertions ON.
2450  assert((!isSorted || llvm::is_sorted(kernels)) &&
2451  "expected a sorted kernel array");
2452  // Immediately return the attribute if the array is sorted.
2453  if (isSorted || llvm::is_sorted(kernels))
2454  return Base::getChecked(emitError, context, kernels);
2455  // Sort the array.
2456  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2457  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2458  return Base::getChecked(emitError, context, kernelsTmp);
2459 }
2460 
2461 LogicalResult
2462 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2463  ArrayRef<KernelMetadataAttr> kernels) {
2464  if (kernels.size() < 2)
2465  return success();
2466  // Check that the kernels are uniquely named.
2467  if (std::adjacent_find(kernels.begin(), kernels.end(),
2468  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2469  return l.getName() == r.getName();
2470  }) != kernels.end()) {
2471  return emitError() << "expected all kernels to be uniquely named";
2472  }
2473  return success();
2474 }
2475 
2476 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2477  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2478  return found ? *iterator : KernelMetadataAttr();
2479 }
2480 
2481 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2482  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2483  return found ? *iterator : KernelMetadataAttr();
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // GPU target options
2488 //===----------------------------------------------------------------------===//
2489 
2490 TargetOptions::TargetOptions(
2491  StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2492  StringRef cmdOptions, StringRef elfSection,
2493  CompilationTarget compilationTarget,
2494  function_ref<SymbolTable *()> getSymbolTableCallback,
2495  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2496  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2497  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2498  function_ref<void(StringRef)> isaCallback)
2499  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2500  cmdOptions, elfSection, compilationTarget,
2501  getSymbolTableCallback, initialLlvmIRCallback,
2502  linkedLlvmIRCallback, optimizedLlvmIRCallback,
2503  isaCallback) {}
2504 
2505 TargetOptions::TargetOptions(
2506  TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2507  StringRef cmdOptions, StringRef elfSection,
2508  CompilationTarget compilationTarget,
2509  function_ref<SymbolTable *()> getSymbolTableCallback,
2510  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2511  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2512  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2513  function_ref<void(StringRef)> isaCallback)
2514  : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2515  cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2516  compilationTarget(compilationTarget),
2517  getSymbolTableCallback(getSymbolTableCallback),
2518  initialLlvmIRCallback(initialLlvmIRCallback),
2519  linkedLlvmIRCallback(linkedLlvmIRCallback),
2520  optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2521  isaCallback(isaCallback), typeID(typeID) {}
2522 
2523 TypeID TargetOptions::getTypeID() const { return typeID; }
2524 
2525 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2526 
2527 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2528  return librariesToLink;
2529 }
2530 
2531 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2532 
2533 StringRef TargetOptions::getELFSection() const { return elfSection; }
2534 
2535 SymbolTable *TargetOptions::getSymbolTable() const {
2536  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2537 }
2538 
2539 function_ref<void(llvm::Module &)>
2540 TargetOptions::getInitialLlvmIRCallback() const {
2541  return initialLlvmIRCallback;
2542 }
2543 
2544 function_ref<void(llvm::Module &)>
2545 TargetOptions::getLinkedLlvmIRCallback() const {
2546  return linkedLlvmIRCallback;
2547 }
2548 
2549 function_ref<void(llvm::Module &)>
2550 TargetOptions::getOptimizedLlvmIRCallback() const {
2551  return optimizedLlvmIRCallback;
2552 }
2553 
2554 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2555  return isaCallback;
2556 }
2557 
2558 CompilationTarget TargetOptions::getCompilationTarget() const {
2559  return compilationTarget;
2560 }
2561 
2562 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2563  return CompilationTarget::Fatbin;
2564 }
2565 
2566 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2567 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2568  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2569  llvm::StringSaver stringSaver(options.first);
2570  StringRef opts = cmdOptions;
2571  // For a correct tokenization of the command line options, `opts` must be
2572  // unquoted; otherwise the tokenization function returns a single token (the
2573  // whole quoted `cmdOptions`), which is not the desired behavior.
2574  // Remove any quotes if they are at the beginning and end of the string:
2575  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2576  opts.consume_front("\""), opts.consume_back("\"");
2577  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2578  opts.consume_front("'"), opts.consume_back("'");
2579 #ifdef _WIN32
2580  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2581  /*MarkEOLs=*/false);
2582 #else
2583  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2584  /*MarkEOLs=*/false);
2585 #endif // _WIN32
2586  return options;
2587 }
2588 
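// For illustration only (not part of this file): a made-up command line and
// the tokens produced by tokenizeCmdOptions above. Surrounding quotes are
// stripped before tokenization:
//   "\"-O3 -arch=sm_90\""  ->  {"-O3", "-arch=sm_90"}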
2589 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2590 TargetOptions::tokenizeCmdOptions() const {
2591  return tokenizeCmdOptions(cmdOptions);
2592 }
2593 
2594 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2595 TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2596  size_t startPos = cmdOptions.find(startsWith);
2597  if (startPos == std::string::npos)
2598  return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2599 
2600  auto tokenized =
2601  tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2602  cmdOptions.resize(startPos);
2603  return tokenized;
2604 }
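
// For illustration only (not part of this file): a made-up example of the
// suffix splitting above. With cmdOptions = "-v --suffix-opts -O3" and
// startsWith = "--suffix-opts", the tokens {"-O3"} are returned and
// cmdOptions is truncated to "-v " (everything before the marker).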
2605 
2606 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2607 
2608 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2609 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2610 
2611 #define GET_ATTRDEF_CLASSES
2612 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2613 
2614 #define GET_OP_CLASSES
2615 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2616 
2617 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
Definition: GPUDialect.cpp:427
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Definition: GPUDialect.cpp:591
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
Definition: GPUDialect.cpp:443
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
Definition: GPUDialect.cpp:902
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
Definition: GPUDialect.cpp:477
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
Definition: GPUDialect.cpp:569
static std::string getSparseHandleKeyword(SparseHandleKind kind)
Definition: GPUDialect.cpp:227
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
Definition: GPUDialect.cpp:317
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Definition: GPUDialect.cpp:604
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
Definition: GPUDialect.cpp:466
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
Definition: GPUDialect.cpp:515
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
Definition: GPUDialect.cpp:853
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
Definition: GPUDialect.cpp:489
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1179::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MINUI(lhs, rhs)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:100
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:90
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:222
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isIndex() const
Definition: Types.cpp:54
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
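A hedged sketch combining the Type predicates above into a small element-type check, in the spirit of MMAMatrixType::isValidElementType documented further below; the predicate name is illustrative.
static bool isSupportedElementType(Type type) {
  // Accept 16/32-bit floats or any integer width for this example.
  return type.isF16() || type.isF32() || type.isInteger();
}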
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:138
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:123
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:142
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
Definition: GPUDialect.cpp:146
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition: GPUDialect.cpp:153
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:144
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
Definition: GPUDialect.cpp:129
unsigned getNumDims() const
Get number of dims.
Definition: GPUDialect.cpp:136
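An illustrative construction exercising the MMAMatrixType accessors listed above; the 16x16xf16 "AOp" matrix is an arbitrary example, and <cassert> is assumed to be included.
static void mmaMatrixTypeExample() {
  MLIRContext ctx;
  ctx.loadDialect<gpu::GPUDialect>();          // MMAMatrixType lives in the GPU dialect
  Type f16 = Float16Type::get(&ctx);
  gpu::MMAMatrixType mat =
      gpu::MMAMatrixType::get({16, 16}, f16, /*operand=*/"AOp");
  assert(mat.getNumDims() == 2 && mat.getShape()[0] == 16);
  assert(mat.getElementType().isF16() && mat.getOperand() == "AOp");
}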
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
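A hedged sketch of consuming the tokenized command-line options described above; the helper name is hypothetical and the allocator in the returned pair owns the token storage.
static void dumpCmdOptions(const gpu::TargetOptions &targetOptions) {
  auto [allocator, tokens] = targetOptions.tokenizeCmdOptions(); // allocator keeps the strings alive
  for (const char *token : tokens)
    llvm::errs() << token << "\n";
}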
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Definition: GPUDialect.cpp:668
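A hedged sketch of mlir::addAsyncDependency from the entry above; `asyncOp` and `token` stand in for an async GPU op and a !gpu.async.token value taken from the surrounding IR.
static void chainOnToken(Operation *asyncOp, Value token) {
  addAsyncDependency(asyncOp, token); // appends `token` to asyncOp's async dependency list
}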
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
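An illustrative use of matchPattern and m_One, in the spirit of the gpu.launch simplification pattern documented a few entries below; the helper name is hypothetical.
static bool isKnownOne(Value dimSize) {
  // True if dimSize is defined by a constant (or splat) integer one.
  return matchPattern(dimSize, m_One());
}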
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR type in the given MLIR context, if it is valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:425
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
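A hedged sketch of assembling an op generically through OperationState and the Operation *create(const OperationState &) entry above; `builder`, `loc`, `lhs`, and `rhs` are hypothetical values from surrounding code, and the arith dialect is assumed to be loaded.
static Operation *buildGenericAddI(OpBuilder &builder, Location loc,
                                   Value lhs, Value rhs) {
  OperationState state(loc, "arith.addi");   // op identified by name
  state.addOperands(ValueRange{lhs, rhs});
  state.types.push_back(lhs.getType());      // result type mirrors the operand type
  return builder.create(state);              // materialize the op at the insertion point
}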
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.
Definition: GPUDialect.h:39
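A hedged illustration of the KernelDim3 triple described above; `launchOp` is a hypothetical gpu::LaunchOp whose thread IDs are exposed as a KernelDim3.
static Value threadIdX(gpu::LaunchOp launchOp) {
  gpu::KernelDim3 tids = launchOp.getThreadIds(); // the .x/.y/.z triple
  return tids.x;                                  // mirrors CUDA's threadIdx.x
}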