//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

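// All device mapping attributes share the same id scheme: ids for the grid
// dimensions DimX, DimY, DimZ come first, followed by the linear dimensions
// starting at LinearDim0. A mapping is "linear" iff its id falls at or beyond
// LinearDim0, and its relative index is the offset within its own group.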
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

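// An MMA matrix type carries a 2-D shape, an element type, and the matrix
// operand role, e.g. !gpu.mma_matrix<16x16xf16, "AOp">.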
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: print refined type here. Notice that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
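/// For example, an async op with two dependencies prints as:
///   %token = gpu.wait async [%dep0, %dep1]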
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
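/// For example:
///   workgroup(%buffer : memref<32xf32, #gpu.address_space<workgroup>>)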
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

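/// Verifies that the reduction kind is compatible with the element type:
/// the floating-point min/max variants require a float type, and the integer
/// min/max and bitwise variants require an integer type.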
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

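/// Returns true if a group reduction op can be marked uniform: the op must
/// sit directly in the entry block of the enclosing gpu.launch, where every
/// thread is known to execute it.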
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

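/// Prepends `token` to the op's operands (async dependencies are always the
/// leading operands). If the op uses AttrSizedOperandSegments, the size of
/// the first segment is bumped to keep the segment attribute consistent.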
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

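// A LaunchOp carries eleven operand segments: the async dependencies, the
// grid sizes (x, y, z), the block sizes (x, y, z), the optional cluster
// sizes (x, y, z), and the optional dynamic shared memory size.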
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  assert(indices.size() == 3 && "space for three indices expected");
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size segments assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number
  // of WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

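// A sketch of the custom launch_func syntax handled below; the symbol and
// value names are illustrative (cluster sizes and the dynamic shared memory
// size are optional):
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)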
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

1589 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1590  Attribute value, StringAttr attrsName) {
1591  MLIRContext *ctx = op.getContext();
1592  SmallVector<NamedAttribute> elems;
1593  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1594  if (oldDict)
1595  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1596 
1597  bool found = false;
1598  bool mustSort = true;
1599  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1600  if (elems[i].getName() == name) {
1601  found = true;
1602  if (!value) {
1603  std::swap(elems[i], elems[elems.size() - 1]);
1604  elems.pop_back();
1605  } else {
1606  mustSort = false;
1607  elems[i] = NamedAttribute(elems[i].getName(), value);
1608  }
1609  break;
1610  }
1611  }
1612  if (!found) {
1613  if (!value)
1614  return;
1615  elems.emplace_back(name, value);
1616  }
1617  if (mustSort) {
1618  DictionaryAttr::sortInPlace(elems);
1619  }
1620  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1621  setAttributionAttrs(op, index, newDict, attrsName);
1622 }
1623 
1624 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1625  Attribute value) {
1626  assert(index < getNumWorkgroupAttributions() &&
1627  "index must map to a workgroup attribution");
1628  setAttributionAttr(*this, index, name, value,
1629  getWorkgroupAttribAttrsAttrName());
1630 }
1631 
1632 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1633  Attribute value) {
1634  assert(index < getNumPrivateAttributions() &&
1635  "index must map to a private attribution");
1636  setAttributionAttr(*this, index, name, value,
1637  getPrivateAttribAttrsAttrName());
1638 }
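// A hedged sketch of the single-entry setters above (assuming `funcOp` has at
// least one workgroup attribution and `b` is an OpBuilder; the attribute name
// is hypothetical):
//
//   funcOp.setWorkgroupAttributionAttr(
//       0, b.getStringAttr("my.alignment"), b.getI64IntegerAttr(16));
//   // Passing a null value removes the entry again:
//   funcOp.setWorkgroupAttributionAttr(
//       0, b.getStringAttr("my.alignment"), Attribute());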
1639 
1640 LogicalResult GPUFuncOp::verifyType() {
1641  if (isKernel() && getFunctionType().getNumResults() != 0)
1642  return emitOpError() << "expected void return type for kernel function";
1643 
1644  return success();
1645 }
1646 
1647 /// Verifies the body of the function.
1648 LogicalResult GPUFuncOp::verifyBody() {
1649  if (empty())
1650  return emitOpError() << "expected body with at least one block";
1651  unsigned numFuncArguments = getNumArguments();
1652  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1653  unsigned numBlockArguments = front().getNumArguments();
1654  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1655  return emitOpError() << "expected at least "
1656  << numFuncArguments + numWorkgroupAttributions
1657  << " arguments to body region";
1658 
1659  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1660  for (unsigned i = 0; i < numFuncArguments; ++i) {
1661  Type blockArgType = front().getArgument(i).getType();
1662  if (funcArgTypes[i] != blockArgType)
1663  return emitOpError() << "expected body region argument #" << i
1664  << " to be of type " << funcArgTypes[i] << ", got "
1665  << blockArgType;
1666  }
1667 
1668  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1669  GPUDialect::getWorkgroupAddressSpace())) ||
1670  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1671  GPUDialect::getPrivateAddressSpace())))
1672  return failure();
1673 
1674  return success();
1675 }
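// Concretely, the entry block checked above lays out its arguments as
// [function arguments..., workgroup attributions..., private attributions...],
// so for a hedged example like
//
//   gpu.func @f(%arg0: f32) workgroup(%w : memref<4xf32, 3>) { ... }
//
// the entry block must begin with an f32 argument followed by the memref.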
1676 
1677 static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
1678  StringRef attrName) {
1679  auto maybeAttr = op->getAttr(attrName);
1680  if (!maybeAttr)
1681  return success();
1682  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
1683  if (!array)
1684  return op.emitOpError(attrName + " must be a dense i32 array");
1685  if (array.size() != 3)
1686  return op.emitOpError(attrName + " must contain exactly 3 elements");
1687  return success();
1688 }
1689 
1690 LogicalResult GPUFuncOp::verify() {
1691  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
1692  return failure();
1693  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
1694  return failure();
1695  return success();
1696 }
1697 
1698 //===----------------------------------------------------------------------===//
1699 // ReturnOp
1700 //===----------------------------------------------------------------------===//
1701 
1702 LogicalResult gpu::ReturnOp::verify() {
1703  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1704 
1705  FunctionType funType = function.getFunctionType();
1706 
1707  if (funType.getNumResults() != getOperands().size())
1708  return emitOpError()
1709  .append("expected ", funType.getNumResults(), " result operands")
1710  .attachNote(function.getLoc())
1711  .append("return type declared here");
1712 
1713  for (const auto &pair : llvm::enumerate(
1714  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1715  auto [type, operand] = pair.value();
1716  if (type != operand.getType())
1717  return emitOpError() << "unexpected type `" << operand.getType()
1718  << "' for operand #" << pair.index();
1719  }
1720  return success();
1721 }
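// A hedged example of the constraint enforced above, for a non-kernel
// gpu.func returning f32:
//
//   gpu.func @f() -> f32 {
//     %c = arith.constant 0.0 : f32
//     gpu.return %c : f32  // Operand count and types must match the results.
//   }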
1722 
1723 //===----------------------------------------------------------------------===//
1724 // GPUModuleOp
1725 //===----------------------------------------------------------------------===//
1726 
1727 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1728  StringRef name, ArrayAttr targets,
1729  Attribute offloadingHandler) {
1730  ensureTerminator(*result.addRegion(), builder, result.location);
1731  result.attributes.push_back(builder.getNamedAttr(
1732  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1733 
1734  Properties &props = result.getOrAddProperties<Properties>();
1735  if (targets)
1736  props.targets = targets;
1737  props.offloadingHandler = offloadingHandler;
1738 }
1739 
1740 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1741  StringRef name, ArrayRef<Attribute> targets,
1742  Attribute offloadingHandler) {
1743  build(builder, result, name,
1744  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1745  offloadingHandler);
1746 }
1747 
1748 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1749  StringAttr nameAttr;
1750  ArrayAttr targetsAttr;
1751 
1752  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1753  result.attributes))
1754  return failure();
1755 
1756  Properties &props = result.getOrAddProperties<Properties>();
1757 
1758  // Parse the optional offloadingHandler
1759  if (succeeded(parser.parseOptionalLess())) {
1760  if (parser.parseAttribute(props.offloadingHandler))
1761  return failure();
1762  if (parser.parseGreater())
1763  return failure();
1764  }
1765 
1766  // Parse the optional array of target attributes.
1767  OptionalParseResult targetsAttrResult =
1768  parser.parseOptionalAttribute(targetsAttr, Type{});
1769  if (targetsAttrResult.has_value()) {
1770  if (failed(*targetsAttrResult)) {
1771  return failure();
1772  }
1773  props.targets = targetsAttr;
1774  }
1775 
1776  // If module attributes are present, parse them.
1777  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1778  return failure();
1779 
1780  // Parse the module body.
1781  auto *body = result.addRegion();
1782  if (parser.parseRegion(*body, {}))
1783  return failure();
1784 
1785  // Ensure that this module has a valid terminator.
1786  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1787  return success();
1788 }
1789 
1790 void GPUModuleOp::print(OpAsmPrinter &p) {
1791  p << ' ';
1792  p.printSymbolName(getName());
1793 
1794  if (Attribute attr = getOffloadingHandlerAttr()) {
1795  p << " <";
1796  p.printAttribute(attr);
1797  p << ">";
1798  }
1799 
1800  if (Attribute attr = getTargetsAttr()) {
1801  p << ' ';
1802  p.printAttribute(attr);
1803  p << ' ';
1804  }
1805 
1806  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1807  {mlir::SymbolTable::getSymbolAttrName(),
1808  getTargetsAttrName(),
1809  getOffloadingHandlerAttrName()});
1810  p << ' ';
1811  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1812  /*printBlockTerminators=*/false);
1813 }
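// A hedged sketch of the concrete syntax this parser/printer pair round-trips
// (the offloading handler and target attributes are illustrative):
//
//   gpu.module @kernels <#gpu.select_object<0>> [#nvvm.target] {
//     ...
//   }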
1814 
1815 bool GPUModuleOp::hasTarget(Attribute target) {
1816  if (ArrayAttr targets = getTargetsAttr())
1817  return llvm::count(targets.getValue(), target);
1818  return false;
1819 }
1820 
1821 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1822  ArrayAttr &targetsAttr = getProperties().targets;
1823  SmallVector<Attribute> targetsVector(targets);
1824  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1825 }
1826 
1827 //===----------------------------------------------------------------------===//
1828 // GPUBinaryOp
1829 //===----------------------------------------------------------------------===//
1830 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1831  Attribute offloadingHandler, ArrayAttr objects) {
1832  auto &properties = result.getOrAddProperties<Properties>();
1833  result.attributes.push_back(builder.getNamedAttr(
1834  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1835  properties.objects = objects;
1836  if (offloadingHandler)
1837  properties.offloadingHandler = offloadingHandler;
1838  else
1839  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1840 }
1841 
1842 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1843  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1844  build(builder, result, name, offloadingHandler,
1845  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1846 }
1847 
1848 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1849  Attribute &offloadingHandler) {
1850  if (succeeded(parser.parseOptionalLess())) {
1851  if (parser.parseAttribute(offloadingHandler))
1852  return failure();
1853  if (parser.parseGreater())
1854  return failure();
1855  }
1856  if (!offloadingHandler)
1857  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1858  return success();
1859 }
1860 
1861 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1862  Attribute offloadingHandler) {
1863  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1864  printer << '<' << offloadingHandler << '>';
1865 }
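// A hedged sketch of gpu.binary syntax using these helpers; the handler is
// printed only when it differs from the default #gpu.select_object<0>, and
// the object payloads are elided here:
//
//   gpu.binary @payload <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#rocdl.target, "...">]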
1866 
1867 //===----------------------------------------------------------------------===//
1868 // GPUMemcpyOp
1869 //===----------------------------------------------------------------------===//
1870 
1871 LogicalResult MemcpyOp::verify() {
1872  auto srcType = getSrc().getType();
1873  auto dstType = getDst().getType();
1874 
1875  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1876  return emitOpError("arguments have incompatible element type");
1877 
1878  if (failed(verifyCompatibleShape(srcType, dstType)))
1879  return emitOpError("arguments have incompatible shape");
1880 
1881  return success();
1882 }
1883 
1884 namespace {
1885 
1886 /// Erases a common case of copy ops where the destination value is used only
1887 /// by the copy op itself and by alloc and dealloc ops.
1888 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1889  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1890 
1891  LogicalResult matchAndRewrite(MemcpyOp op,
1892  PatternRewriter &rewriter) const override {
1893  Value dest = op.getDst();
1894  Operation *destDefOp = dest.getDefiningOp();
1895  // `dest` must be defined by an op having Allocate memory effect in order to
1896  // perform the folding.
1897  if (!destDefOp ||
1898  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1899  return failure();
1900  // We can erase `op` iff `dest` has no other use apart from its
1901  // use by `op` and dealloc ops.
1902  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1903  return user != op &&
1904  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1905  }))
1906  return failure();
1907  // We can perform the folding if and only if op has a single async
1908  // dependency and produces an async token as result, or if it does not have
1909  // any async dependency and does not produce any async token result.
1910  if (op.getAsyncDependencies().size() > 1 ||
1911  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1912  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1913  return failure();
1914  rewriter.replaceOp(op, op.getAsyncDependencies());
1915  return success();
1916  }
1917 };
1918 
1919 } // end anonymous namespace
1920 
1921 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1922  MLIRContext *context) {
1923  results.add<EraseTrivialCopyOp>(context);
1924 }
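// A hedged before/after sketch of EraseTrivialCopyOp: when the destination is
// produced by an allocation and is only ever copied into and deallocated, the
// copy is dead and gets erased (the alloc/dealloc are left for other
// cleanups):
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>  // Erased.
//   gpu.dealloc %dst : memref<16xf32>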
1925 
1926 //===----------------------------------------------------------------------===//
1927 // GPU_SubgroupMmaLoadMatrixOp
1928 //===----------------------------------------------------------------------===//
1929 
1930 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1931  auto srcType = getSrcMemref().getType();
1932  auto resType = getRes().getType();
1933  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1934  auto operand = resMatrixType.getOperand();
1935  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1936 
1937  if (!isLastMemrefDimUnitStride(srcMemrefType))
1938  return emitError(
1939  "expected source memref most minor dim must have unit stride");
1940 
1941  if (!operand.equals("AOp") && !operand.equals("BOp") &&
1942  !operand.equals("COp"))
1943  return emitError("only AOp, BOp and COp can be loaded");
1944 
1945  return success();
1946 }
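// A hedged sketch of a load that satisfies both checks above (unit-stride
// minor dimension, operand tag in {AOp, BOp, COp}):
//
//   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//       : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">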
1947 
1948 //===----------------------------------------------------------------------===//
1949 // GPU_SubgroupMmaStoreMatrixOp
1950 //===----------------------------------------------------------------------===//
1951 
1952 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1953  auto srcType = getSrc().getType();
1954  auto dstType = getDstMemref().getType();
1955  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1956  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1957 
1958  if (!isLastMemrefDimUnitStride(dstMemrefType))
1959  return emitError(
1960  "expected destination memref most minor dim must have unit stride");
1961 
1962  if (!srcMatrixType.getOperand().equals("COp"))
1963  return emitError(
1964  "expected the operand matrix being stored to have 'COp' operand type");
1965 
1966  return success();
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // GPU_SubgroupMmaComputeOp
1971 //===----------------------------------------------------------------------===//
1972 
1973 LogicalResult SubgroupMmaComputeOp::verify() {
1974  enum OperandMap { A, B, C };
1975  SmallVector<MMAMatrixType, 3> opTypes;
1976  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1977  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1978  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1979 
1980  if (!opTypes[A].getOperand().equals("AOp") ||
1981  !opTypes[B].getOperand().equals("BOp") ||
1982  !opTypes[C].getOperand().equals("COp"))
1983  return emitError("operands must be in the order AOp, BOp, COp");
1984 
1985  ArrayRef<int64_t> aShape, bShape, cShape;
1986  aShape = opTypes[A].getShape();
1987  bShape = opTypes[B].getShape();
1988  cShape = opTypes[C].getShape();
1989 
1990  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1991  bShape[1] != cShape[1])
1992  return emitError("operand shapes do not satisfy matmul constraints");
1993 
1994  return success();
1995 }
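// For C += A * B with A : !gpu.mma_matrix<MxKxf16, "AOp">, B : <KxNxf16,
// "BOp">, and C : <MxNxf32, "COp">, the shape checks above amount to
// aShape[1] == bShape[0] (shared K), aShape[0] == cShape[0] (M), and
// bShape[1] == cShape[1] (N); e.g. 16x16 shapes for all three operands pass.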
1996 
1997 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1998  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1999  return memref::foldMemRefCast(*this);
2000 }
2001 
2002 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2003  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2004  return memref::foldMemRefCast(*this);
2005 }
2006 
2007 //===----------------------------------------------------------------------===//
2008 // GPU_WaitOp
2009 //===----------------------------------------------------------------------===//
2010 
2011 namespace {
2012 
2013 /// Removes a gpu.wait op's use of a gpu.wait def that has no async dependencies:
2014 /// %t = gpu.wait async [] // No async dependencies.
2015 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2016 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2017 public:
2018  using OpRewritePattern::OpRewritePattern;
2019 
2020  LogicalResult matchAndRewrite(WaitOp op,
2021  PatternRewriter &rewriter) const final {
2022  auto predicate = [](Value value) {
2023  auto waitOp = value.getDefiningOp<WaitOp>();
2024  return waitOp && waitOp->getNumOperands() == 0;
2025  };
2026  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2027  return failure();
2028  SmallVector<Value> validOperands;
2029  for (Value operand : op->getOperands()) {
2030  if (predicate(operand))
2031  continue;
2032  validOperands.push_back(operand);
2033  }
2034  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2035  return success();
2036  }
2037 };
2038 
2039 /// Simplify trivial gpu.wait ops for the following patterns.
2040 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2041 /// dependencies).
2042 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2043 /// %t0.
2044 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2045 /// dependencies nor return any token.
2046 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2047 public:
2048  using OpRewritePattern::OpRewritePattern;
2049 
2050  LogicalResult matchAndRewrite(WaitOp op,
2051  PatternRewriter &rewriter) const final {
2052  // Erase gpu.wait ops that neither have any async dependencies nor return
2053  // any async token.
2054  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2055  rewriter.eraseOp(op);
2056  return success();
2057  }
2058  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2059  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2060  op.getAsyncToken()) {
2061  rewriter.replaceOp(op, op.getAsyncDependencies());
2062  return success();
2063  }
2064  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2065  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2066  rewriter.eraseOp(op);
2067  return success();
2068  }
2069  return failure();
2070  }
2071 };
2072 
2073 } // end anonymous namespace
2074 
2075 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2076  MLIRContext *context) {
2077  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2078 }
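// A hedged sketch of the two patterns above acting together:
//
//   %t0 = gpu.wait async            // No dependencies.
//   %t1 = gpu.wait async [%t0, %d]  // %t0 is dropped from the list.
//   %t2 = gpu.wait async [%t1]      // Uses of %t2 are replaced by %t1.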
2079 
2080 //===----------------------------------------------------------------------===//
2081 // GPU_AllocOp
2082 //===----------------------------------------------------------------------===//
2083 
2084 LogicalResult AllocOp::verify() {
2085  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2086 
2087  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2088  memRefType.getNumDynamicDims())
2089  return emitOpError("dimension operand count does not equal memref "
2090  "dynamic dimension count");
2091 
2092  unsigned numSymbols = 0;
2093  if (!memRefType.getLayout().isIdentity())
2094  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2095  if (getSymbolOperands().size() != numSymbols) {
2096  return emitOpError(
2097  "symbol operand count does not equal memref symbol count");
2098  }
2099 
2100  return success();
2101 }
2102 
2103 namespace {
2104 
2105 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2106 /// `memref::AllocOp`.
2107 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2108  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2109 
2110  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2111  PatternRewriter &rewriter) const override {
2112  std::optional<int64_t> index = dimOp.getConstantIndex();
2113  if (!index)
2114  return failure();
2115 
2116  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2117  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2118  return failure();
2119 
2120  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2121  if (!alloc)
2122  return failure();
2123 
2124  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2125  memrefType.getDynamicDimIndex(index.value()));
2126  rewriter.replaceOp(dimOp, substituteOp);
2127  return success();
2128  }
2129 };
2130 
2131 } // namespace
2132 
2133 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2134  MLIRContext *context) {
2135  results.add<SimplifyDimOfAllocOp>(context);
2136 }
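// A hedged sketch of SimplifyDimOfAllocOp: querying a dynamic dimension of a
// gpu.alloc result folds to the corresponding size operand.
//
//   %m = gpu.alloc (%n) : memref<?xf32>
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %m, %c0 : memref<?xf32>  // Uses of %d fold to %n.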
2137 
2138 //===----------------------------------------------------------------------===//
2139 // GPU object attribute
2140 //===----------------------------------------------------------------------===//
2141 
2142 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2143  Attribute target, CompilationTarget format,
2144  StringAttr object, DictionaryAttr properties) {
2145  if (!target)
2146  return emitError() << "the target attribute cannot be null";
2147  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2148  return success();
2149  return emitError() << "the target attribute must implement or promise the "
2150  "`gpu::TargetAttrInterface`";
2151 }
2152 
2153 namespace {
2154 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2155  StringAttr &object) {
2156  std::optional<CompilationTarget> formatResult;
2157  StringRef enumKeyword;
2158  auto loc = odsParser.getCurrentLocation();
2159  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2160  formatResult = CompilationTarget::Fatbin;
2161  if (!formatResult &&
2162  (formatResult =
2163  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2164  odsParser.parseEqual())
2165  return odsParser.emitError(loc, "expected an equal sign");
2166  if (!formatResult)
2167  return odsParser.emitError(loc, "expected keyword for GPU object format");
2168  FailureOr<StringAttr> objectResult =
2169  FieldParser<StringAttr>::parse(odsParser);
2170  if (failed(objectResult))
2171  return odsParser.emitError(odsParser.getCurrentLocation(),
2172  "failed to parse GPU_ObjectAttr parameter "
2173  "'object' which is to be a `StringAttr`");
2174  format = *formatResult;
2175  object = *objectResult;
2176  return success();
2177 }
2178 
2179 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2180  StringAttr object) {
2181  if (format != CompilationTarget::Fatbin)
2182  odsParser << stringifyEnum(format) << " = ";
2183  odsParser << object;
2184 }
2185 } // namespace
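// A hedged sketch of the attribute syntax handled above; the format keyword
// is elided for the default fatbin target, and the payloads are illustrative:
//
//   #gpu.object<#nvvm.target, "...">             // Implicit fatbin format.
//   #gpu.object<#nvvm.target, assembly = "...">  // Explicit format keyword.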
2186 
2187 //===----------------------------------------------------------------------===//
2188 // GPU select object attribute
2189 //===----------------------------------------------------------------------===//
2190 
2191 LogicalResult
2192 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2193  Attribute target) {
2194  // Check `target`; it can be null, an integer attr, or a GPU Target attribute.
2195  if (target) {
2196  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2197  if (intAttr.getInt() < 0) {
2198  return emitError() << "the object index must be non-negative";
2199  }
2200  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2201  return emitError()
2202  << "the target attribute must be a GPU Target attribute";
2203  }
2204  }
2205  return success();
2206 }
2207 
2208 //===----------------------------------------------------------------------===//
2209 // DynamicSharedMemoryOp
2210 //===----------------------------------------------------------------------===//
2211 
2212 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2213  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2214  return emitOpError() << "must be inside an op with symbol table";
2215 
2216  MemRefType memrefType = getResultMemref().getType();
2217  // Check address space
2218  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2219  return emitOpError() << "address space must be "
2220  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2221  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2222  }
2223  if (memrefType.hasStaticShape()) {
2224  return emitOpError() << "result memref type must be memref<?xi8, "
2225  "#gpu.address_space<workgroup>>";
2226  }
2227  return success();
2228 }
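// Per the checks above, the only well-formed result type is a dynamically
// shaped i8 memref in workgroup memory:
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>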
2229 
2230 //===----------------------------------------------------------------------===//
2231 // GPU target options
2232 //===----------------------------------------------------------------------===//
2233 
2234 TargetOptions::TargetOptions(
2235  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2236  StringRef cmdOptions, CompilationTarget compilationTarget,
2237  function_ref<SymbolTable *()> getSymbolTableCallback)
2238  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2239  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2240 
2241 TargetOptions::TargetOptions(
2242  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2243  StringRef cmdOptions, CompilationTarget compilationTarget,
2244  function_ref<SymbolTable *()> getSymbolTableCallback)
2245  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2246  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2247  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2248 
2249 TypeID TargetOptions::getTypeID() const { return typeID; }
2250 
2251 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2252 
2253 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2254 
2255 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2256 
2257 SymbolTable *TargetOptions::getSymbolTable() const {
2258  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2259 }
2260 
2261 CompilationTarget TargetOptions::getCompilationTarget() const {
2262  return compilationTarget;
2263 }
2264 
2265 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2266  return CompilationTarget::Fatbin;
2267 }
2268 
2269 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2270 TargetOptions::tokenizeCmdOptions() const {
2271  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2272  llvm::StringSaver stringSaver(options.first);
2273  StringRef opts = cmdOptions;
2274  // For correct tokenization of the command-line options, `opts` must be
2275  // unquoted; otherwise the tokenization function returns a single token:
2276  // the still-quoted `cmdOptions`, which is not the desired behavior.
2277  // Strip the quotes if they appear at both the beginning and end of the string:
2278  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2279  opts.consume_front("\""), opts.consume_back("\"");
2280  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2281  opts.consume_front("'"), opts.consume_back("'");
2282 #ifdef _WIN32
2283  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2284  /*MarkEOLs=*/false);
2285 #else
2286  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2287  /*MarkEOLs=*/false);
2288 #endif // _WIN32
2289  return options;
2290 }
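// A hedged example of the behavior above: with cmdOptions set to the quoted
// string "\"-O3 --dump\"", the surrounding quotes are stripped first and the
// GNU/Windows tokenizer then yields the two tokens {"-O3", "--dump"}.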
2291 
2292 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2293 
2294 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2295 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2296 
2297 #define GET_ATTRDEF_CLASSES
2298 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2299 
2300 #define GET_OP_CLASSES
2301 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2302 
2303 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"