//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
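
// For reference, a fragment type that passes the checks above is written in
// IR as, e.g., !gpu.mma_matrix<16x16xf16, "AOp"> (illustrative example; the
// shapes a target actually supports are checked elsewhere).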

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}
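
// These keywords appear in IR behind the dialect prefix, e.g. (illustrative):
//   !gpu.sparse.dntensor_handle, !gpu.sparse.spmat_handle,
//   !gpu.sparse.spgemmop_handle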

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here; note that the printed form must
// correspond to what the parser above accepts.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
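
// For orientation, the module structure this verifier enforces looks roughly
// like the following (illustrative sketch; the names are made up):
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0 : f32) kernel { gpu.return }
//     }
//     // ... inside a function nested directly in the module:
//     gpu.launch_func @kernels::@kernel
//         blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
//         args(%v : f32)
//   }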

/// Parses an optional list of async operands with an optional leading keyword.
///    (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
///    (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}
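
// As a reference point, the async clause handled by the two helpers above
// looks like (illustrative): %token = gpu.wait async [%dep0, %dep1]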

// GPU Memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}
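
// An attribution list printed by the helper above looks like (illustrative):
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)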

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
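
// The two accepted forms checked above are, roughly (illustrative examples):
//
//   %sum = gpu.all_reduce add %val uniform {} : (f32) -> (f32)
//
//   %max = gpu.all_reduce %val {
//   ^bb(%lhs : f32, %rhs : f32):
//     %m = arith.maximumf %lhs, %rhs : f32
//     "gpu.yield"(%m) : (f32) -> ()
//   } : (f32) -> (f32)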

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only consider ops in the gpu.launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies are the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
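
// E.g., if operandSegmentSizes was [2, 1, 1, 1] (two async dependencies),
// inserting a token at position 0 yields [3, 1, 1, 1]: only the leading
// variadic segment grows; the remaining segments are untouched.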

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
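///
/// For example (illustrative):
///
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
///              threads(%tx, %ty, %tz) in (%bw = %c8, %bh = %c1, %bd = %c1) {
///     // body, using %bx ... %tz
///     gpu.terminator
///   }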
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size operands assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
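
// E.g., for `gpu.launch blocks(%bx, ...) in (%gx = %c1, ...)`, the grid size
// along x is statically one, so every use of %bx inside the body is replaced
// by a constant 0 (an illustration of the pattern above).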

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}
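
// The clause handled by the two helpers above looks like (illustrative):
//   args(%arg0 : f32, %arg1 : memref<?xf32>)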

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
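
// This builder wraps the constant-operand form of the op, which in IR reads,
// e.g. (illustrative): %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32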

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
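
// E.g., of two back-to-back gpu.barrier ops, the first is erased:
//   gpu.barrier   // erased as redundant by the pattern above
//   gpu.barrier   // kept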

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
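
// A function printed by the method above looks like (illustrative):
//   gpu.func @kernel(%arg0 : f32)
//       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//       kernel {
//     gpu.return
//   }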
1524 
1525 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1526  StringAttr attrName) {
1527  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1528  if (!allAttrs || index >= allAttrs.size())
1529  return DictionaryAttr();
1530  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1531 }
1532 
1533 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1534  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1535 }
1536 
1537 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1538  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1539 }
1540 
1541 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1542  DictionaryAttr value, StringAttr attrName) {
1543  MLIRContext *ctx = op.getContext();
1544  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1545  SmallVector<Attribute> elements;
1546  if (allAttrs)
1547  elements.append(allAttrs.begin(), allAttrs.end());
1548  while (elements.size() <= index)
1549  elements.push_back(DictionaryAttr::get(ctx));
1550  if (!value)
1551  elements[index] = DictionaryAttr::get(ctx);
1552  else
1553  elements[index] = value;
1554  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1555  op->setAttr(attrName, newValue);
1556 }
1557 
1558 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1559  DictionaryAttr value) {
1560  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1561 }
1562 
1563 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1564  DictionaryAttr value) {
1565  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1566 }
1567 
1568 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1569  StringAttr name, StringAttr attrsName) {
1570  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1571  if (!dict)
1572  return Attribute();
1573  return dict.get(name);
1574 }
1575 
1576 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1577  StringAttr name) {
1578  assert(index < getNumWorkgroupAttributions() &&
1579  "index must map to a workgroup attribution");
1580  return getAttributionAttr(*this, index, name,
1581  getWorkgroupAttribAttrsAttrName());
1582 }
1583 
1584 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1585  StringAttr name) {
1586  assert(index < getNumPrivateAttributions() &&
1587  "index must map to a private attribution");
1588  return getAttributionAttr(*this, index, name,
1589  getPrivateAttribAttrsAttrName());
1590 }
1591 
1592 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1593  Attribute value, StringAttr attrsName) {
1594  MLIRContext *ctx = op.getContext();
1595  SmallVector<NamedAttribute> elems;
1596  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1597  if (oldDict)
1598  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1599 
1600  bool found = false;
1601  bool mustSort = true;
1602  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1603  if (elems[i].getName() == name) {
1604  found = true;
1605  if (!value) {
1606  std::swap(elems[i], elems[elems.size() - 1]);
1607  elems.pop_back();
1608  } else {
1609  mustSort = false;
1610  elems[i] = NamedAttribute(elems[i].getName(), value);
1611  }
1612  break;
1613  }
1614  }
1615  if (!found) {
1616  if (!value)
1617  return;
1618  elems.emplace_back(name, value);
1619  }
1620  if (mustSort) {
1621  DictionaryAttr::sortInPlace(elems);
1622  }
1623  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1624  setAttributionAttrs(op, index, newDict, attrsName);
1625 }
1626 
1627 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1628  Attribute value) {
1629  assert(index < getNumWorkgroupAttributions() &&
1630  "index must map to a workgroup attribution");
1631  setAttributionAttr(*this, index, name, value,
1632  getWorkgroupAttribAttrsAttrName());
1633 }
1634 
1635 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1636  Attribute value) {
1637  assert(index < getNumPrivateAttributions() &&
1638  "index must map to a private attribution");
1639  setAttributionAttr(*this, index, name, value,
1640  getPrivateAttribAttrsAttrName());
1641 }
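// A minimal usage sketch of the attribution-attribute API above, assuming a
// GPUFuncOp `func` and a Builder `b` (the attribute name "my.alignment" is
// made up for the example):
//
//   func.setWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"),
//                                    b.getI64IntegerAttr(16));
//   Attribute a =
//       func.getWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"));
//   // Passing a null value removes the entry again.
//   func.setWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"),
//                                    Attribute());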
1642 
1643 LogicalResult GPUFuncOp::verifyType() {
1644  if (isKernel() && getFunctionType().getNumResults() != 0)
1645  return emitOpError() << "expected void return type for kernel function";
1646 
1647  return success();
1648 }
1649 
1650 /// Verifies the body of the function.
1651 LogicalResult GPUFuncOp::verifyBody() {
1652  if (empty())
1653  return emitOpError() << "expected body with at least one block";
1654  unsigned numFuncArguments = getNumArguments();
1655  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1656  unsigned numBlockArguments = front().getNumArguments();
1657  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1658  return emitOpError() << "expected at least "
1659  << numFuncArguments + numWorkgroupAttributions
1660  << " arguments to body region";
1661 
1662  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1663  for (unsigned i = 0; i < numFuncArguments; ++i) {
1664  Type blockArgType = front().getArgument(i).getType();
1665  if (funcArgTypes[i] != blockArgType)
1666  return emitOpError() << "expected body region argument #" << i
1667  << " to be of type " << funcArgTypes[i] << ", got "
1668  << blockArgType;
1669  }
1670 
1671  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1672  GPUDialect::getWorkgroupAddressSpace())) ||
1673  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1674  GPUDialect::getPrivateAddressSpace())))
1675  return failure();
1676 
1677  return success();
1678 }
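// In other words, the entry block arguments of a gpu.func are laid out as the
// function arguments (matching the function type) followed by the workgroup
// attributions and then the private attributions. For example, a function
// with one f32 argument and one workgroup attribution has an entry block of
// the form ^bb0(%arg0: f32, %w: memref<32xf32, 3>).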
1679 
1680 static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
1681  StringRef attrName) {
1682  auto maybeAttr = op->getAttr(attrName);
1683  if (!maybeAttr)
1684  return success();
1685  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
1686  if (!array)
1687  return op.emitOpError(attrName + " must be a dense i32 array");
1688  if (array.size() != 3)
1689  return op.emitOpError(attrName + " must contain exactly 3 elements");
1690  return success();
1691 }
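// For example, assuming the discardable attribute is spelled
// gpu.known_block_size, a well-formed annotation is a dense i32 array with
// exactly three elements:
//
//   gpu.func @kernel() kernel
//       attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
//     gpu.return
//   }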
1692 
1693 LogicalResult GPUFuncOp::verify() {
1694  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
1695  return failure();
1696  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
1697  return failure();
1698  return success();
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // ReturnOp
1703 //===----------------------------------------------------------------------===//
1704 
1705 LogicalResult gpu::ReturnOp::verify() {
1706  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1707 
1708  FunctionType funType = function.getFunctionType();
1709 
1710  if (funType.getNumResults() != getOperands().size())
1711  return emitOpError()
1712  .append("expected ", funType.getNumResults(), " result operands")
1713  .attachNote(function.getLoc())
1714  .append("return type declared here");
1715 
1716  for (const auto &pair : llvm::enumerate(
1717  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1718  auto [type, operand] = pair.value();
1719  if (type != operand.getType())
1720  return emitOpError() << "unexpected type `" << operand.getType()
1721  << "' for operand #" << pair.index();
1722  }
1723  return success();
1724 }
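// For example, in a non-kernel gpu.func with result type f32, the terminator
// must be something like `gpu.return %v : f32`; kernel functions, having no
// results, use a bare `gpu.return`.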
1725 
1726 //===----------------------------------------------------------------------===//
1727 // GPUModuleOp
1728 //===----------------------------------------------------------------------===//
1729 
1730 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1731  StringRef name, ArrayAttr targets,
1732  Attribute offloadingHandler) {
1733  ensureTerminator(*result.addRegion(), builder, result.location);
1734  result.attributes.push_back(builder.getNamedAttr(
1735  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1736 
1737  Properties &props = result.getOrAddProperties<Properties>();
1738  if (targets)
1739  props.targets = targets;
1740  props.offloadingHandler = offloadingHandler;
1741 }
1742 
1743 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1744  StringRef name, ArrayRef<Attribute> targets,
1745  Attribute offloadingHandler) {
1746  build(builder, result, name,
1747  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1748  offloadingHandler);
1749 }
1750 
1751 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1752  StringAttr nameAttr;
1753  ArrayAttr targetsAttr;
1754 
1755  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1756  result.attributes))
1757  return failure();
1758 
1759  Properties &props = result.getOrAddProperties<Properties>();
1760 
1761  // Parse the optional offloadingHandler
1762  if (succeeded(parser.parseOptionalLess())) {
1763  if (parser.parseAttribute(props.offloadingHandler))
1764  return failure();
1765  if (parser.parseGreater())
1766  return failure();
1767  }
1768 
1769  // Parse the optional array of target attributes.
1770  OptionalParseResult targetsAttrResult =
1771  parser.parseOptionalAttribute(targetsAttr, Type{});
1772  if (targetsAttrResult.has_value()) {
1773  if (failed(*targetsAttrResult)) {
1774  return failure();
1775  }
1776  props.targets = targetsAttr;
1777  }
1778 
1779  // If module attributes are present, parse them.
1780  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1781  return failure();
1782 
1783  // Parse the module body.
1784  auto *body = result.addRegion();
1785  if (parser.parseRegion(*body, {}))
1786  return failure();
1787 
1788  // Ensure that this module has a valid terminator.
1789  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1790  return success();
1791 }
1792 
1793 void GPUModuleOp::print(OpAsmPrinter &p) {
1794  p << ' ';
1795  p.printSymbolName(getName());
1796 
1797  if (Attribute attr = getOffloadingHandlerAttr()) {
1798  p << " <";
1799  p.printAttribute(attr);
1800  p << ">";
1801  }
1802 
1803  if (Attribute attr = getTargetsAttr()) {
1804  p << ' ';
1805  p.printAttribute(attr);
1806  p << ' ';
1807  }
1808 
1809  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1810  {mlir::SymbolTable::getSymbolAttrName(),
1811  getTargetsAttrName(),
1812  getOffloadingHandlerAttrName()});
1813  p << ' ';
1814  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1815  /*printBlockTerminators=*/false);
1816 }
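// Putting the parser and printer together, a gpu.module round-trips in forms
// such as the following (the offloading handler and target attribute are
// illustrative):
//
//   gpu.module @kernels {
//   }
//   gpu.module @kernels <#gpu.select_object<0>> [#nvvm.target] {
//   }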
1817 
1818 bool GPUModuleOp::hasTarget(Attribute target) {
1819  if (ArrayAttr targets = getTargetsAttr())
1820  return llvm::count(targets.getValue(), target);
1821  return false;
1822 }
1823 
1824 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1825  ArrayAttr &targetsAttr = getProperties().targets;
1826  SmallVector<Attribute> targetsVector(targets);
1827  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1828 }
1829 
1830 //===----------------------------------------------------------------------===//
1831 // GPUBinaryOp
1832 //===----------------------------------------------------------------------===//
1833 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1834  Attribute offloadingHandler, ArrayAttr objects) {
1835  auto &properties = result.getOrAddProperties<Properties>();
1836  result.attributes.push_back(builder.getNamedAttr(
1837  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1838  properties.objects = objects;
1839  if (offloadingHandler)
1840  properties.offloadingHandler = offloadingHandler;
1841  else
1842  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1843 }
1844 
1845 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1846  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1847  build(builder, result, name, offloadingHandler,
1848  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1849 }
1850 
1851 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1852  Attribute &offloadingHandler) {
1853  if (succeeded(parser.parseOptionalLess())) {
1854  if (parser.parseAttribute(offloadingHandler))
1855  return failure();
1856  if (parser.parseGreater())
1857  return failure();
1858  }
1859  if (!offloadingHandler)
1860  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1861  return success();
1862 }
1863 
1864 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1865  Attribute offloadingHandler) {
1866  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1867  printer << '<' << offloadingHandler << '>';
1868 }
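// These two hooks implement the optional <handler> clause of gpu.binary: the
// default handler (select the first object) is elided, anything else prints
// between angle brackets. Illustrative forms, with object contents shown as
// a "..." placeholder:
//
//   gpu.binary @program [#gpu.object<#nvvm.target, "...">]
//   gpu.binary @program <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#nvvm.target, "...">]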
1869 
1870 //===----------------------------------------------------------------------===//
1871 // GPUMemcpyOp
1872 //===----------------------------------------------------------------------===//
1873 
1874 LogicalResult MemcpyOp::verify() {
1875  auto srcType = getSrc().getType();
1876  auto dstType = getDst().getType();
1877 
1878  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1879  return emitOpError("arguments have incompatible element type");
1880 
1881  if (failed(verifyCompatibleShape(srcType, dstType)))
1882  return emitOpError("arguments have incompatible shape");
1883 
1884  return success();
1885 }
1886 
1887 namespace {
1888 
1889 /// Erases a common case of copy ops where the destination value is used only
1890 /// by the copy op itself and by alloc and dealloc ops.
1891 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1892  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1893 
1894  LogicalResult matchAndRewrite(MemcpyOp op,
1895  PatternRewriter &rewriter) const override {
1896  Value dest = op.getDst();
1897  Operation *destDefOp = dest.getDefiningOp();
1898  // `dest` must be defined by an op having Allocate memory effect in order to
1899  // perform the folding.
1900  if (!destDefOp ||
1901  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1902  return failure();
1903  // We can erase `op` iff `dest` has no other use apart from its
1904  // use by `op` and dealloc ops.
1905  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1906  return user != op &&
1907  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1908  }))
1909  return failure();
1910  // We can perform the folding if and only if op has a single async
1911  // dependency and produces an async token as result, or if it does not have
1912  // any async dependency and does not produce any async token result.
1913  if (op.getAsyncDependencies().size() > 1 ||
1914  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1915  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1916  return failure();
1917  rewriter.replaceOp(op, op.getAsyncDependencies());
1918  return success();
1919  }
1920 };
1921 
1922 } // end anonymous namespace
1923 
1924 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1925  MLIRContext *context) {
1926  results.add<EraseTrivialCopyOp>(context);
1927 }
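// Example of a copy erased by EraseTrivialCopyOp (a sketch; values are
// invented): %dst is written only by the copy and otherwise just
// deallocated, so the gpu.memcpy can be removed, leaving the alloc/dealloc
// for other patterns to clean up.
//
//   %dst = gpu.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>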
1928 
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaLoadMatrixOp
1931 //===----------------------------------------------------------------------===//
1932 
1933 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1934  auto srcType = getSrcMemref().getType();
1935  auto resType = getRes().getType();
1936  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1937  auto operand = resMatrixType.getOperand();
1938  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1939 
1940  if (!isLastMemrefDimUnitStride(srcMemrefType))
1941  return emitError(
1942  "expected source memref most minor dim must have unit stride");
1943 
1944  if (!operand.equals("AOp") && !operand.equals("BOp") &&
1945  !operand.equals("COp"))
1946  return emitError("only AOp, BOp and COp can be loaded");
1947 
1948  return success();
1949 }
1950 
1951 //===----------------------------------------------------------------------===//
1952 // GPU_SubgroupMmaStoreMatrixOp
1953 //===----------------------------------------------------------------------===//
1954 
1955 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1956  auto srcType = getSrc().getType();
1957  auto dstType = getDstMemref().getType();
1958  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1959  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1960 
1961  if (!isLastMemrefDimUnitStride(dstMemrefType))
1962  return emitError(
1963  "expected destination memref most minor dim must have unit stride");
1964 
1965  if (!srcMatrixType.getOperand().equals("COp"))
1966  return emitError(
1967  "expected the operand matrix being stored to have 'COp' operand type");
1968 
1969  return success();
1970 }
1971 
1972 //===----------------------------------------------------------------------===//
1973 // GPU_SubgroupMmaComputeOp
1974 //===----------------------------------------------------------------------===//
1975 
1976 LogicalResult SubgroupMmaComputeOp::verify() {
1977  enum OperandMap { A, B, C };
1978  SmallVector<MMAMatrixType, 3> opTypes;
1979  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1980  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1981  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1982 
1983  if (!opTypes[A].getOperand().equals("AOp") ||
1984  !opTypes[B].getOperand().equals("BOp") ||
1985  !opTypes[C].getOperand().equals("COp"))
1986  return emitError("operands must be in the order AOp, BOp, COp");
1987 
1988  ArrayRef<int64_t> aShape, bShape, cShape;
1989  aShape = opTypes[A].getShape();
1990  bShape = opTypes[B].getShape();
1991  cShape = opTypes[C].getShape();
1992 
1993  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1994  bShape[1] != cShape[1])
1995  return emitError("operand shapes do not satisfy matmul constraints");
1996 
1997  return success();
1998 }
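// The shape check encodes C += A * B: with A of shape m x k ("AOp") and B of
// shape k x n ("BOp"), the accumulator C must be m x n ("COp"), e.g.:
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//       : !gpu.mma_matrix<16x16xf16, "AOp">,
//         !gpu.mma_matrix<16x16xf16, "BOp">
//         -> !gpu.mma_matrix<16x16xf16, "COp">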
1999 
2000 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2001  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2002  return memref::foldMemRefCast(*this);
2003 }
2004 
2005 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2006  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2007  return memref::foldMemRefCast(*this);
2008 }
2009 
2010 //===----------------------------------------------------------------------===//
2011 // GPU_WaitOp
2012 //===----------------------------------------------------------------------===//
2013 
2014 namespace {
2015 
2016 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2017 /// %t = gpu.wait async [] // No async dependencies.
2018 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2019 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2020 public:
2021  using OpRewritePattern::OpRewritePattern;
2022 
2023  LogicalResult matchAndRewrite(WaitOp op,
2024  PatternRewriter &rewriter) const final {
2025  auto predicate = [](Value value) {
2026  auto waitOp = value.getDefiningOp<WaitOp>();
2027  return waitOp && waitOp->getNumOperands() == 0;
2028  };
2029  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2030  return failure();
2031  SmallVector<Value> validOperands;
2032  for (Value operand : op->getOperands()) {
2033  if (predicate(operand))
2034  continue;
2035  validOperands.push_back(operand);
2036  }
2037  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2038  return success();
2039  }
2040 };
2041 
2042 /// Simplify trivial gpu.wait ops for the following patterns.
2043 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2044 /// dependencies).
2045 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2046 /// %t0.
2047 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2048 /// dependencies nor return any token.
2049 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2050 public:
2051  using OpRewritePattern::OpRewritePattern;
2052 
2053  LogicalResult matchAndRewrite(WaitOp op,
2054  PatternRewriter &rewriter) const final {
2055  // Erase gpu.wait ops that neither have any async dependencies nor return
2056  // any async token.
2057  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2058  rewriter.eraseOp(op);
2059  return success();
2060  }
2061  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2062  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2063  op.getAsyncToken()) {
2064  rewriter.replaceOp(op, op.getAsyncDependencies());
2065  return success();
2066  }
2067  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2068  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2069  rewriter.eraseOp(op);
2070  return success();
2071  }
2072  return failure();
2073  }
2074 };
2075 
2076 } // end anonymous namespace
2077 
2078 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2079  MLIRContext *context) {
2080  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2081 }
2082 
2083 //===----------------------------------------------------------------------===//
2084 // GPU_AllocOp
2085 //===----------------------------------------------------------------------===//
2086 
2087 LogicalResult AllocOp::verify() {
2088  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2089 
2090  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2091  memRefType.getNumDynamicDims())
2092  return emitOpError("dimension operand count does not equal memref "
2093  "dynamic dimension count");
2094 
2095  unsigned numSymbols = 0;
2096  if (!memRefType.getLayout().isIdentity())
2097  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2098  if (getSymbolOperands().size() != numSymbols) {
2099  return emitOpError(
2100  "symbol operand count does not equal memref symbol count");
2101  }
2102 
2103  return success();
2104 }
2105 
2106 namespace {
2107 
2108 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2109 /// `memref::AllocOp`.
2110 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2111  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2112 
2113  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2114  PatternRewriter &rewriter) const override {
2115  std::optional<int64_t> index = dimOp.getConstantIndex();
2116  if (!index)
2117  return failure();
2118 
2119  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2120  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2121  return failure();
2122 
2123  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2124  if (!alloc)
2125  return failure();
2126 
2127  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2128  memrefType.getDynamicDimIndex(index.value()));
2129  rewriter.replaceOp(dimOp, substituteOp);
2130  return success();
2131  }
2132 };
2133 
2134 } // namespace
2135 
2136 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2137  MLIRContext *context) {
2138  results.add<SimplifyDimOfAllocOp>(context);
2139 }
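// Example of the folding performed by SimplifyDimOfAllocOp:
//
//   %m = gpu.alloc(%size) : memref<?xf32>
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %m, %c0 : memref<?xf32>  // folds to %size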
2140 
2141 //===----------------------------------------------------------------------===//
2142 // GPU object attribute
2143 //===----------------------------------------------------------------------===//
2144 
2145 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2146  Attribute target, CompilationTarget format,
2147  StringAttr object, DictionaryAttr properties) {
2148  if (!target)
2149  return emitError() << "the target attribute cannot be null";
2150  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2151  return success();
2152  return emitError() << "the target attribute must implement or promise the "
2153  "`gpu::TargetAttrInterface`";
2154 }
2155 
2156 namespace {
2157 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2158  StringAttr &object) {
2159  std::optional<CompilationTarget> formatResult;
2160  StringRef enumKeyword;
2161  auto loc = odsParser.getCurrentLocation();
2162  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2163  formatResult = CompilationTarget::Fatbin;
2164  if (!formatResult &&
2165  (formatResult =
2166  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2167  odsParser.parseEqual())
2168  return odsParser.emitError(loc, "expected an equal sign");
2169  if (!formatResult)
2170  return odsParser.emitError(loc, "expected keyword for GPU object format");
2171  FailureOr<StringAttr> objectResult =
2172  FieldParser<StringAttr>::parse(odsParser);
2173  if (failed(objectResult))
2174  return odsParser.emitError(odsParser.getCurrentLocation(),
2175  "failed to parse GPU_ObjectAttr parameter "
2176  "'object' which is to be a `StringAttr`");
2177  format = *formatResult;
2178  object = *objectResult;
2179  return success();
2180 }
2181 
2182 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2183  StringAttr object) {
2184  if (format != CompilationTarget::Fatbin)
2185  odsParser << stringifyEnum(format) << " = ";
2186  odsParser << object;
2187 }
2188 } // namespace
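// These hooks give #gpu.object its textual forms; the format keyword is
// elided for the default fatbin format. Illustrative examples, with object
// contents shown as a "..." placeholder:
//
//   #gpu.object<#nvvm.target, "...">
//   #gpu.object<#nvvm.target, assembly = "...">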
2189 
2190 //===----------------------------------------------------------------------===//
2191 // GPU select object attribute
2192 //===----------------------------------------------------------------------===//
2193 
2194 LogicalResult
2195 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2196  Attribute target) {
2197  // Check `target`; it can be null, an integer attribute, or a GPU target attribute.
2198  if (target) {
2199  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2200  if (intAttr.getInt() < 0) {
2201  return emitError() << "the object index must be non-negative";
2202  }
2203  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2204  return emitError()
2205  << "the target attribute must be a GPU Target attribute";
2206  }
2207  }
2208  return success();
2209 }
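// Accordingly, both of the following are valid selectors (a non-negative
// object index, or a GPU target attribute; the target shown is illustrative):
//
//   #gpu.select_object<0>
//   #gpu.select_object<#nvvm.target>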
2210 
2211 //===----------------------------------------------------------------------===//
2212 // DynamicSharedMemoryOp
2213 //===----------------------------------------------------------------------===//
2214 
2215 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2216  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2217  return emitOpError() << "must be inside an op with symbol table";
2218 
2219  MemRefType memrefType = getResultMemref().getType();
2220  // Check address space
2221  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2222  return emitOpError() << "address space must be "
2223  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2224  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2225  }
2226  if (memrefType.hasStaticShape()) {
2227  return emitOpError() << "result memref type must be memref<?xi8, "
2228  "#gpu.address_space<workgroup>>";
2229  }
2230  return success();
2231 }
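// A well-formed use therefore looks like:
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>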
2232 
2233 //===----------------------------------------------------------------------===//
2234 // GPU target options
2235 //===----------------------------------------------------------------------===//
2236 
2237 TargetOptions::TargetOptions(
2238  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2239  StringRef cmdOptions, CompilationTarget compilationTarget,
2240  function_ref<SymbolTable *()> getSymbolTableCallback)
2241  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2242  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2243 
2244 TargetOptions::TargetOptions(
2245  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2246  StringRef cmdOptions, CompilationTarget compilationTarget,
2247  function_ref<SymbolTable *()> getSymbolTableCallback)
2248  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2249  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2250  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2251 
2252 TypeID TargetOptions::getTypeID() const { return typeID; }
2253 
2254 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2255 
2256 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2257 
2258 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2259 
2260 SymbolTable *TargetOptions::getSymbolTable() const {
2261  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2262 }
2263 
2264 CompilationTarget TargetOptions::getCompilationTarget() const {
2265  return compilationTarget;
2266 }
2267 
2268 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2269  return CompilationTarget::Fatbin;
2270 }
2271 
2272 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2273 TargetOptions::tokenizeCmdOptions() const {
2274  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2275  llvm::StringSaver stringSaver(options.first);
2276  StringRef opts = cmdOptions;
2277  // For correct tokenization, the command line options in `opts` must be
2278  // unquoted; otherwise the tokenization function returns a single string
2279  // (the unquoted `cmdOptions`), which is not the desired behavior.
2280  // Remove the quotes if they surround the whole string:
2281  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2282  opts.consume_front("\""), opts.consume_back("\"");
2283  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2284  opts.consume_front("'"), opts.consume_back("'");
2285 #ifdef _WIN32
2286  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2287  /*MarkEOLs=*/false);
2288 #else
2289  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2290  /*MarkEOLs=*/false);
2291 #endif // _WIN32
2292  return options;
2293 }
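// For example, a cmdOptions value of "\"-O3 --fast-math\"" (flags invented
// for the example) is unquoted and then tokenized into {"-O3",
// "--fast-math"}; the BumpPtrAllocator in the returned pair owns the storage
// backing the token strings.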
2294 
2295 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2296 
2297 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2298 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2299 
2300 #define GET_ATTRDEF_CLASSES
2301 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2302 
2303 #define GET_OP_CLASSES
2304 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2305 
2306 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"