//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

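// For illustration (values assumed from the MappingId enum order): a
// grid-style attribute such as #gpu.thread<y> yields mapping id 1 and
// relative index 1, while a linear attribute such as
// #gpu.thread<linear_dim_2> yields an id >= MappingId::LinearDim0 and
// relative index 2.
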
//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

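// For illustration, a well-formed fragment type in the textual IR reads
// !gpu.mma_matrix<16x16xf16, "AOp">: a 2-D shape, a supported element type,
// and one of the three operand roles checked above.
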
//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with the leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

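// For illustration, these two hooks round-trip syntax of the form
//   %token = gpu.wait async [%dep0, %dep1]
// where `async` produces a !gpu.async.token result and the bracketed list
// holds the token dependencies.
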
// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

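// For illustration, an attribution printed by the helper above reads
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
// (the shape and element type here are just an example).
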
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

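// For illustration, the two accepted forms of the op (syntax from the GPU
// dialect documentation):
//   %sum = gpu.all_reduce add %val {} : (f32) -> (f32)
//   %sum = gpu.all_reduce %val {
//   ^bb(%lhs : f32, %rhs : f32):
//     %s = arith.addf %lhs, %rhs : f32
//     "gpu.yield"(%s) : (f32) -> ()
//   } : (f32) -> (f32)
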
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

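// For illustration, a subgroup reduction in the textual IR (syntax from the
// GPU dialect documentation):
//   %sum = gpu.subgroup_reduce add %val : (f32) -> (f32)
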
//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  assert(indices.size() == 3 && "space for three indices expected");
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size operands assign the cluster size; in the region
  // argument list, the cluster values are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

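// For illustration, the custom syntax parsed above looks like (sizes assumed
// to be `index` values):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
//              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
//     gpu.terminator
//   }
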
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

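// For illustration, if a launch binds a size to a constant 1, e.g.
//   blocks(%bx, %by, %bz) in (%gx = %c1, ...)
// the pattern replaces every use of %bx in the body with a newly created
// arith.constant 0 : index at the start of the entry block.
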
void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects the types of the cluster dimensions to be the same";
  }

  return success();
}

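// For illustration, a launch this verifier accepts (adapted from the GPU
// dialect documentation; the kernel lives in a gpu.module named @kernels):
//   gpu.launch_func @kernels::@kernel_1
//       blocks in (%c8, %c8, %c8) threads in (%c8, %c8, %c8)
//       args(%arg0 : f32, %arg1 : memref<?xf32, 1>)
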
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *op,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

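// For illustration, the built op prints as (syntax from the GPU dialect
// documentation):
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32
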
//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

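// For illustration, the pattern rewrites
//   gpu.barrier
//   gpu.barrier
// into a single gpu.barrier, erasing a barrier whenever the next operation
// is also a barrier.
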
//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

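// For illustration, a function in the form parsed above (adapted from the
// GPU dialect documentation):
//   gpu.func @foo(%arg0: index)
//       workgroup(%workgroup: memref<32xf32, 3>)
//       private(%private: memref<1xf32, 5>)
//       kernel attributes {qux: "quux"} {
//     gpu.return
//   }
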
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

1553 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1554  DictionaryAttr value, StringAttr attrName) {
1555  MLIRContext *ctx = op.getContext();
1556  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1557  SmallVector<Attribute> elements;
1558  if (allAttrs)
1559  elements.append(allAttrs.begin(), allAttrs.end());
1560  while (elements.size() <= index)
1561  elements.push_back(DictionaryAttr::get(ctx));
1562  if (!value)
1563  elements[index] = DictionaryAttr::get(ctx);
1564  else
1565  elements[index] = value;
1566  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1567  op->setAttr(attrName, newValue);
1568 }
1569 
1570 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1571  DictionaryAttr value) {
1572  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1573 }
1574 
1575 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1576  DictionaryAttr value) {
1577  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1578 }
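// A minimal usage sketch (hypothetical attribute name "my.tag", for
// illustration only): attach a dictionary to the first private attribution.
//
//   Builder b(funcOp->getContext());
//   funcOp.setPrivateAttributionAttrs(
//       /*index=*/0,
//       b.getDictionaryAttr(b.getNamedAttr("my.tag", b.getUnitAttr())));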
1579 
1580 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1581  StringAttr name, StringAttr attrsName) {
1582  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1583  if (!dict)
1584  return Attribute();
1585  return dict.get(name);
1586 }
1587 
1588 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1589  StringAttr name) {
1590  assert(index < getNumWorkgroupAttributions() &&
1591  "index must map to a workgroup attribution");
1592  return getAttributionAttr(*this, index, name,
1593  getWorkgroupAttribAttrsAttrName());
1594 }
1595 
1596 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1597  StringAttr name) {
1598  assert(index < getNumPrivateAttributions() &&
1599  "index must map to a private attribution");
1600  return getAttributionAttr(*this, index, name,
1601  getPrivateAttribAttrsAttrName());
1602 }
1603 
1604 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1605  Attribute value, StringAttr attrsName) {
1606  MLIRContext *ctx = op.getContext();
1607  SmallVector<NamedAttribute> elems;
1608  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1609  if (oldDict)
1610  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1611 
1612  bool found = false;
1613  bool mustSort = true;
1614  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1615  if (elems[i].getName() == name) {
1616  found = true;
1617  if (!value) {
1618  std::swap(elems[i], elems[elems.size() - 1]);
1619  elems.pop_back();
1620  } else {
1621  mustSort = false;
1622  elems[i] = NamedAttribute(elems[i].getName(), value);
1623  }
1624  break;
1625  }
1626  }
1627  if (!found) {
1628  if (!value)
1629  return;
1630  elems.emplace_back(name, value);
1631  }
1632  if (mustSort) {
1633  DictionaryAttr::sortInPlace(elems);
1634  }
1635  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1636  setAttributionAttrs(op, index, newDict, attrsName);
1637 }
1638 
1639 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1640  Attribute value) {
1641  assert(index < getNumWorkgroupAttributions() &&
1642  "index must map to a workgroup attribution");
1643  setAttributionAttr(*this, index, name, value,
1644  getWorkgroupAttribAttrsAttrName());
1645 }
1646 
1647 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1648  Attribute value) {
1649  assert(index < getNumPrivateAttributions() &&
1650  "index must map to a private attribution");
1651  setAttributionAttr(*this, index, name, value,
1652  getPrivateAttribAttrsAttrName());
1653 }
1654 
1655 LogicalResult GPUFuncOp::verifyType() {
1656  if (isKernel() && getFunctionType().getNumResults() != 0)
1657  return emitOpError() << "expected void return type for kernel function";
1658 
1659  return success();
1660 }
1661 
1662 /// Verifies the body of the function.
1663 LogicalResult GPUFuncOp::verifyBody() {
1664  if (empty())
1665  return emitOpError() << "expected body with at least one block";
1666  unsigned numFuncArguments = getNumArguments();
1667  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1668  unsigned numBlockArguments = front().getNumArguments();
1669  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1670  return emitOpError() << "expected at least "
1671  << numFuncArguments + numWorkgroupAttributions
1672  << " arguments to body region";
1673 
1674  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1675  for (unsigned i = 0; i < numFuncArguments; ++i) {
1676  Type blockArgType = front().getArgument(i).getType();
1677  if (funcArgTypes[i] != blockArgType)
1678  return emitOpError() << "expected body region argument #" << i
1679  << " to be of type " << funcArgTypes[i] << ", got "
1680  << blockArgType;
1681  }
1682 
1683  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1684  GPUDialect::getWorkgroupAddressSpace())) ||
1685  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1686  GPUDialect::getPrivateAddressSpace())))
1687  return failure();
1688 
1689  return success();
1690 }
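// Entry-block argument layout checked above, for illustration (hypothetical
// types): two function arguments, one workgroup attribution, and one private
// attribution appear in exactly that order:
//
//   ^bb0(%arg0: f32, %arg1: memref<?xf32>,   // function arguments
//        %wg: memref<32xf32, 3>,             // workgroup attribution
//        %pv: memref<1xf32, 5>):             // private attribution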
1691 
1692 static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
1693  StringRef attrName) {
1694  auto maybeAttr = op->getAttr(attrName);
1695  if (!maybeAttr)
1696  return success();
1697  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
1698  if (!array)
1699  return op.emitOpError(attrName + " must be a dense i32 array");
1700  if (array.size() != 3)
1701  return op.emitOpError(attrName + " must contain exactly 3 elements");
1702  return success();
1703 }
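// For illustration, a well-formed launch-size attribute is a 3-element dense
// i32 array, e.g. (using the known-block-size attribute name):
//
//   gpu.known_block_size = array<i32: 128, 1, 1>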
1704 
1705 LogicalResult GPUFuncOp::verify() {
1706  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
1707  return failure();
1708  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
1709  return failure();
1710  return success();
1711 }
1712 
1713 //===----------------------------------------------------------------------===//
1714 // ReturnOp
1715 //===----------------------------------------------------------------------===//
1716 
1717 LogicalResult gpu::ReturnOp::verify() {
1718  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1719 
1720  FunctionType funType = function.getFunctionType();
1721 
1722  if (funType.getNumResults() != getOperands().size())
1723  return emitOpError()
1724  .append("expected ", funType.getNumResults(), " result operands")
1725  .attachNote(function.getLoc())
1726  .append("return type declared here");
1727 
1728  for (const auto &pair : llvm::enumerate(
1729  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1730  auto [type, operand] = pair.value();
1731  if (type != operand.getType())
1732  return emitOpError() << "unexpected type `" << operand.getType()
1733  << "' for operand #" << pair.index();
1734  }
1735  return success();
1736 }
1737 
1738 //===----------------------------------------------------------------------===//
1739 // GPUModuleOp
1740 //===----------------------------------------------------------------------===//
1741 
1742 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1743  StringRef name, ArrayAttr targets,
1744  Attribute offloadingHandler) {
1745  ensureTerminator(*result.addRegion(), builder, result.location);
1746 
1747  Properties &props = result.getOrAddProperties<Properties>();
1748  if (targets)
1749  props.targets = targets;
1750  props.setSymName(builder.getStringAttr(name));
1751  props.offloadingHandler = offloadingHandler;
1752 }
1753 
1754 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1755  StringRef name, ArrayRef<Attribute> targets,
1756  Attribute offloadingHandler) {
1757  build(builder, result, name,
1758  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1759  offloadingHandler);
1760 }
1761 
1762 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1763  StringAttr nameAttr;
1764  ArrayAttr targetsAttr;
1765 
1766  if (parser.parseSymbolName(nameAttr))
1767  return failure();
1768 
1769  Properties &props = result.getOrAddProperties<Properties>();
1770  props.setSymName(nameAttr);
1771 
1772  // Parse the optional offloadingHandler
1773  if (succeeded(parser.parseOptionalLess())) {
1774  if (parser.parseAttribute(props.offloadingHandler))
1775  return failure();
1776  if (parser.parseGreater())
1777  return failure();
1778  }
1779 
1780  // Parse the optional array of target attributes.
1781  OptionalParseResult targetsAttrResult =
1782  parser.parseOptionalAttribute(targetsAttr, Type{});
1783  if (targetsAttrResult.has_value()) {
1784  if (failed(*targetsAttrResult)) {
1785  return failure();
1786  }
1787  props.targets = targetsAttr;
1788  }
1789 
1790  // If module attributes are present, parse them.
1791  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1792  return failure();
1793 
1794  // Parse the module body.
1795  auto *body = result.addRegion();
1796  if (parser.parseRegion(*body, {}))
1797  return failure();
1798 
1799  // Ensure that this module has a valid terminator.
1800  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1801  return success();
1802 }
1803 
1804 void GPUModuleOp::print(OpAsmPrinter &p) {
1805  p << ' ';
1806  p.printSymbolName(getName());
1807 
1808  if (Attribute attr = getOffloadingHandlerAttr()) {
1809  p << " <";
1810  p.printAttribute(attr);
1811  p << ">";
1812  }
1813 
1814  if (Attribute attr = getTargetsAttr()) {
1815  p << ' ';
1816  p.printAttribute(attr);
1817  p << ' ';
1818  }
1819 
1820  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
1821  {mlir::SymbolTable::getSymbolAttrName(),
1822  getTargetsAttrName(),
1823  getOffloadingHandlerAttrName()});
1824  p << ' ';
1825  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1826  /*printBlockTerminators=*/false);
1827 }
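// An illustrative round-trip form with both optional pieces present
// (hypothetical offloading handler and target attributes):
//
//   gpu.module @kernels <#gpu.select_object<0>> [#nvvm.target] {
//     ...
//   }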
1828 
1829 bool GPUModuleOp::hasTarget(Attribute target) {
1830  if (ArrayAttr targets = getTargetsAttr())
1831  return llvm::is_contained(targets.getValue(), target);
1832  return false;
1833 }
1834 
1835 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1836  ArrayAttr &targetsAttr = getProperties().targets;
1837  SmallVector<Attribute> targetsVector(targets);
1838  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1839 }
1840 
1841 //===----------------------------------------------------------------------===//
1842 // GPUBinaryOp
1843 //===----------------------------------------------------------------------===//
1844 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1845  Attribute offloadingHandler, ArrayAttr objects) {
1846  auto &properties = result.getOrAddProperties<Properties>();
1847  result.attributes.push_back(builder.getNamedAttr(
1848  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1849  properties.objects = objects;
1850  if (offloadingHandler)
1851  properties.offloadingHandler = offloadingHandler;
1852  else
1853  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1854 }
1855 
1856 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1857  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1858  build(builder, result, name, offloadingHandler,
1859  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1860 }
1861 
1862 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1863  Attribute &offloadingHandler) {
1864  if (succeeded(parser.parseOptionalLess())) {
1865  if (parser.parseAttribute(offloadingHandler))
1866  return failure();
1867  if (parser.parseGreater())
1868  return failure();
1869  }
1870  if (!offloadingHandler)
1871  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1872  return success();
1873 }
1874 
1875 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1876  Attribute offloadingHandler) {
1877  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1878  printer << '<' << offloadingHandler << '>';
1879 }
1880 
1881 //===----------------------------------------------------------------------===//
1882 // GPUMemcpyOp
1883 //===----------------------------------------------------------------------===//
1884 
1885 LogicalResult MemcpyOp::verify() {
1886  auto srcType = getSrc().getType();
1887  auto dstType = getDst().getType();
1888 
1889  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1890  return emitOpError("arguments have incompatible element type");
1891 
1892  if (failed(verifyCompatibleShape(srcType, dstType)))
1893  return emitOpError("arguments have incompatible shape");
1894 
1895  return success();
1896 }
1897 
1898 namespace {
1899 
1900 /// Erases a common case of copy ops where the destination value is used only
1901 /// by the copy op itself together with alloc and dealloc ops.
1902 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1903  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1904 
1905  LogicalResult matchAndRewrite(MemcpyOp op,
1906  PatternRewriter &rewriter) const override {
1907  Value dest = op.getDst();
1908  Operation *destDefOp = dest.getDefiningOp();
1909  // `dest` must be defined by an op having Allocate memory effect in order to
1910  // perform the folding.
1911  if (!destDefOp ||
1912  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1913  return failure();
1914  // We can erase `op` iff `dest` has no other use apart from its
1915  // use by `op` and dealloc ops.
1916  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1917  return user != op &&
1918  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1919  }))
1920  return failure();
1921  // We can perform the folding if and only if op has a single async
1922  // dependency and produces an async token as result, or if it does not have
1923  // any async dependency and does not produce any async token result.
1924  if (op.getAsyncDependencies().size() > 1 ||
1925  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1926  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1927  return failure();
1928  rewriter.replaceOp(op, op.getAsyncDependencies());
1929  return success();
1930  }
1931 };
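// For illustration (hypothetical values), the pattern erases the copy in IR of
// this shape, leaving the now-unused alloc/dealloc pair to other cleanup
// patterns:
//
//   %dst = gpu.alloc (%n) : memref<?xf32>
//   gpu.memcpy %dst, %src : memref<?xf32>, memref<?xf32>
//   gpu.dealloc %dst : memref<?xf32>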
1932 
1933 } // end anonymous namespace
1934 
1935 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1936  MLIRContext *context) {
1937  results.add<EraseTrivialCopyOp>(context);
1938 }
1939 
1940 //===----------------------------------------------------------------------===//
1941 // GPU_SubgroupMmaLoadMatrixOp
1942 //===----------------------------------------------------------------------===//
1943 
1944 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1945  auto srcType = getSrcMemref().getType();
1946  auto resType = getRes().getType();
1947  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1948  auto operand = resMatrixType.getOperand();
1949  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1950 
1951  if (!isLastMemrefDimUnitStride(srcMemrefType))
1952  return emitError(
1953  "expected the source memref's most minor dimension to have unit stride");
1954 
1955  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1956  return emitError("only AOp, BOp and COp can be loaded");
1957 
1958  return success();
1959 }
1960 
1961 //===----------------------------------------------------------------------===//
1962 // GPU_SubgroupMmaStoreMatrixOp
1963 //===----------------------------------------------------------------------===//
1964 
1965 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1966  auto srcType = getSrc().getType();
1967  auto dstType = getDstMemref().getType();
1968  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1969  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1970 
1971  if (!isLastMemrefDimUnitStride(dstMemrefType))
1972  return emitError(
1973  "expected the destination memref's most minor dimension to have unit stride");
1974 
1975  if (srcMatrixType.getOperand() != "COp")
1976  return emitError(
1977  "expected the operand matrix being stored to have 'COp' operand type");
1978 
1979  return success();
1980 }
1981 
1982 //===----------------------------------------------------------------------===//
1983 // GPU_SubgroupMmaComputeOp
1984 //===----------------------------------------------------------------------===//
1985 
1986 LogicalResult SubgroupMmaComputeOp::verify() {
1987  enum OperandMap { A, B, C };
1988  SmallVector<MMAMatrixType, 3> opTypes;
1989  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1990  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1991  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1992 
1993  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1994  opTypes[C].getOperand() != "COp")
1995  return emitError("operands must be in the order AOp, BOp, COp");
1996 
1997  ArrayRef<int64_t> aShape, bShape, cShape;
1998  aShape = opTypes[A].getShape();
1999  bShape = opTypes[B].getShape();
2000  cShape = opTypes[C].getShape();
2001 
2002  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2003  bShape[1] != cShape[1])
2004  return emitError("operand shapes do not satisfy matmul constraints");
2005 
2006  return success();
2007 }
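// Worked example: with A : 16x16 ("AOp"), B : 16x16 ("BOp"), and C : 16x16
// ("COp"), the shape checks encode C[m,n] += A[m,k] * B[k,n], i.e.
// aShape[1] == bShape[0] (k), aShape[0] == cShape[0] (m), and
// bShape[1] == cShape[1] (n).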
2008 
2009 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2010  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2011  return memref::foldMemRefCast(*this);
2012 }
2013 
2014 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2015  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2016  return memref::foldMemRefCast(*this);
2017 }
2018 
2019 //===----------------------------------------------------------------------===//
2020 // GPU_WaitOp
2021 //===----------------------------------------------------------------------===//
2022 
2023 namespace {
2024 
2025 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2026 /// %t = gpu.wait async [] // No async dependencies.
2027 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2028 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2029 public:
2030  using OpRewritePattern::OpRewritePattern;
2031 
2032  LogicalResult matchAndRewrite(WaitOp op,
2033  PatternRewriter &rewriter) const final {
2034  auto predicate = [](Value value) {
2035  auto waitOp = value.getDefiningOp<WaitOp>();
2036  return waitOp && waitOp->getNumOperands() == 0;
2037  };
2038  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2039  return failure();
2040  SmallVector<Value> validOperands;
2041  for (Value operand : op->getOperands()) {
2042  if (predicate(operand))
2043  continue;
2044  validOperands.push_back(operand);
2045  }
2046  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2047  return success();
2048  }
2049 };
2050 
2051 /// Simplify trivial gpu.wait ops for the following patterns.
2052 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2053 /// dependencies).
2054 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2055 /// %t0.
2056 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2057 /// dependencies nor return any token.
2058 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2059 public:
2060  using OpRewritePattern::OpRewritePattern;
2061 
2062  LogicalResult matchAndRewrite(WaitOp op,
2063  PatternRewriter &rewriter) const final {
2064  // Erase gpu.wait ops that neither have any async dependencies nor return
2065  // any async token.
2066  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2067  rewriter.eraseOp(op);
2068  return success();
2069  }
2070  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2071  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2072  op.getAsyncToken()) {
2073  rewriter.replaceOp(op, op.getAsyncDependencies());
2074  return success();
2075  }
2076  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2077  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2078  rewriter.eraseOp(op);
2079  return success();
2080  }
2081  return failure();
2082  }
2083 };
2084 
2085 } // end anonymous namespace
2086 
2087 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2088  MLIRContext *context) {
2089  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2090 }
2091 
2092 //===----------------------------------------------------------------------===//
2093 // GPU_AllocOp
2094 //===----------------------------------------------------------------------===//
2095 
2096 LogicalResult AllocOp::verify() {
2097  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2098 
2099  if (static_cast<int64_t>(getDynamicSizes().size()) !=
2100  memRefType.getNumDynamicDims())
2101  return emitOpError("dimension operand count does not equal memref "
2102  "dynamic dimension count");
2103 
2104  unsigned numSymbols = 0;
2105  if (!memRefType.getLayout().isIdentity())
2106  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2107  if (getSymbolOperands().size() != numSymbols) {
2108  return emitOpError(
2109  "symbol operand count does not equal memref symbol count");
2110  }
2111 
2112  return success();
2113 }
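// For illustration: one dynamic dimension requires exactly one size operand,
// and the identity layout requires zero symbol operands:
//
//   %mem = gpu.alloc (%width) : memref<64x?xf32, 1>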
2114 
2115 namespace {
2116 
2117 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2118 /// `memref::AllocOp`.
2119 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2120  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2121 
2122  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2123  PatternRewriter &rewriter) const override {
2124  std::optional<int64_t> index = dimOp.getConstantIndex();
2125  if (!index)
2126  return failure();
2127 
2128  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2129  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2130  return failure();
2131 
2132  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2133  if (!alloc)
2134  return failure();
2135 
2136  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2137  memrefType.getDynamicDimIndex(index.value()));
2138  rewriter.replaceOp(dimOp, substituteOp);
2139  return success();
2140  }
2141 };
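// Illustrative fold: in
//
//   %m = gpu.alloc (%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
//
// uses of %d are replaced with %size.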
2142 
2143 } // namespace
2144 
2145 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2146  MLIRContext *context) {
2147  results.add<SimplifyDimOfAllocOp>(context);
2148 }
2149 
2150 //===----------------------------------------------------------------------===//
2151 // GPU object attribute
2152 //===----------------------------------------------------------------------===//
2153 
2154 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2155  Attribute target, CompilationTarget format,
2156  StringAttr object, DictionaryAttr properties) {
2157  if (!target)
2158  return emitError() << "the target attribute cannot be null";
2159  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2160  return success();
2161  return emitError() << "the target attribute must implement or promise the "
2162  "`gpu::TargetAttrInterface`";
2163 }
2164 
2165 namespace {
2166 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2167  StringAttr &object) {
2168  std::optional<CompilationTarget> formatResult;
2169  StringRef enumKeyword;
2170  auto loc = odsParser.getCurrentLocation();
2171  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2172  formatResult = CompilationTarget::Fatbin;
2173  if (!formatResult &&
2174  (formatResult =
2175  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2176  odsParser.parseEqual())
2177  return odsParser.emitError(loc, "expected an equal sign");
2178  if (!formatResult)
2179  return odsParser.emitError(loc, "expected keyword for GPU object format");
2180  FailureOr<StringAttr> objectResult =
2181  FieldParser<StringAttr>::parse(odsParser);
2182  if (failed(objectResult))
2183  return odsParser.emitError(odsParser.getCurrentLocation(),
2184  "failed to parse GPU_ObjectAttr parameter "
2185  "'object' which is to be a `StringAttr`");
2186  format = *formatResult;
2187  object = *objectResult;
2188  return success();
2189 }
2190 
2191 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2192  StringAttr object) {
2193  if (format != CompilationTarget::Fatbin)
2194  odsParser << stringifyEnum(format) << " = ";
2195  odsParser << object;
2196 }
2197 } // namespace
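// Illustrative printed forms (hypothetical target and payload): the default
// fatbin format elides the format keyword, while any other format is spelled
// out with an equal sign:
//
//   #gpu.object<#nvvm.target, "BLOB">          // CompilationTarget::Fatbin
//   #gpu.object<#nvvm.target, bin = "BLOB">    // CompilationTarget::Binary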
2198 
2199 //===----------------------------------------------------------------------===//
2200 // GPU select object attribute
2201 //===----------------------------------------------------------------------===//
2202 
2203 LogicalResult
2204 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2205  Attribute target) {
2206  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2207  if (target) {
2208  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2209  if (intAttr.getInt() < 0) {
2210  return emitError() << "the object index must be non-negative";
2211  }
2212  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2213  return emitError()
2214  << "the target attribute must be a GPU Target attribute";
2215  }
2216  }
2217  return success();
2218 }
2219 
2220 //===----------------------------------------------------------------------===//
2221 // DynamicSharedMemoryOp
2222 //===----------------------------------------------------------------------===//
2223 
2224 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2225  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2226  return emitOpError() << "must be inside an op with symbol table";
2227 
2228  MemRefType memrefType = getResultMemref().getType();
2229  // Check address space
2230  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2231  return emitOpError() << "address space must be "
2232  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2233  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2234  }
2235  if (memrefType.hasStaticShape()) {
2236  return emitOpError() << "result memref type must be memref<?xi8, "
2237  "#gpu.address_space<workgroup>>";
2238  }
2239  return success();
2240 }
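// For illustration, a conforming use: the result shape must be dynamic and the
// address space must be the workgroup space.
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>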
2241 
2242 //===----------------------------------------------------------------------===//
2243 // GPU target options
2244 //===----------------------------------------------------------------------===//
2245 
2246 TargetOptions::TargetOptions(
2247  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2248  StringRef cmdOptions, CompilationTarget compilationTarget,
2249  function_ref<SymbolTable *()> getSymbolTableCallback)
2250  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2251  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2252 
2253 TargetOptions::TargetOptions(
2254  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2255  StringRef cmdOptions, CompilationTarget compilationTarget,
2256  function_ref<SymbolTable *()> getSymbolTableCallback)
2257  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2258  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2259  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2260 
2261 TypeID TargetOptions::getTypeID() const { return typeID; }
2262 
2263 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2264 
2265 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2266 
2267 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2268 
2269 SymbolTable *TargetOptions::getSymbolTable() const {
2270  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2271 }
2272 
2273 CompilationTarget TargetOptions::getCompilationTarget() const {
2274  return compilationTarget;
2275 }
2276 
2277 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2278  return CompilationTarget::Fatbin;
2279 }
2280 
2281 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2282 TargetOptions::tokenizeCmdOptions() const {
2283  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2284  llvm::StringSaver stringSaver(options.first);
2285  StringRef opts = cmdOptions;
2286  // For a correct tokenization of the command line options, `opts` must be
2287  // unquoted; otherwise the tokenization function returns a single string (the
2288  // unquoted `cmdOptions`), which is not the desired behavior.
2289  // Remove any quotes if they are at the beginning and end of the string:
2290  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2291  opts.consume_front("\""), opts.consume_back("\"");
2292  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2293  opts.consume_front("'"), opts.consume_back("'");
2294 #ifdef _WIN32
2295  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2296  /*MarkEOLs=*/false);
2297 #else
2298  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2299  /*MarkEOLs=*/false);
2300 #endif // _WIN32
2301  return options;
2302 }
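// A usage sketch (hypothetical option string): with cmdOptions set to
// "\"-O3 -v\"", the surrounding quotes are stripped first, and the platform
// tokenizer then yields the two tokens "-O3" and "-v", whose storage lives in
// the returned BumpPtrAllocator.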
2303 
2304 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2305 
2306 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2307 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2308 
2309 #define GET_ATTRDEF_CLASSES
2310 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2311 
2312 #define GET_OP_CLASSES
2313 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2314 
2315 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"