//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
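
// Illustrative example (not from the original source): a fragment type that
// satisfies the invariants above, as it would be written in IR. The shape,
// element type, and operand string are arbitrary choices.
//
//   !gpu.mma_matrix<16x16xf16, "AOp">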

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static StringRef getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
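
// Illustrative examples (not from the original source) of type spellings the
// parser above accepts, assuming the leading `!gpu.` prefix has already been
// consumed by the dialect parsing machinery:
//
//   !gpu.async.token
//   !gpu.mma_matrix<16x16xf16, "COp">
//   !gpu.sparse.spmat_handle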
// TODO: Print the refined type here. Note that this should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
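
// Illustrative example (not from the original source) of the syntax this
// helper parses, shown here on gpu.wait; the token names are made up.
//
//   %t = gpu.wait async [%t0, %t1]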

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  auto printBlockArg = [](BlockArgument v) {
    return llvm::formatv("{} : {}", v, v.getType());
  };
  p << ' ' << keyword << '('
    << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}
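
// Illustrative example (not from the original source) of a printed workgroup
// attribution as it appears on gpu.launch or gpu.func; the buffer name and
// shape are made up.
//
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)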

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
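
// Illustrative examples (not from the original source) of the two forms the
// verifier above accepts: an op attribute, or a reduction body that yields a
// value of the operand type. Value names are made up.
//
//   %sum = gpu.all_reduce add %val {} : (f32) -> (f32)
//
//   %max = gpu.all_reduce %val {
//   ^bb0(%lhs : f32, %rhs : f32):
//     %m = arith.maximumf %lhs, %rhs : f32
//     "gpu.yield"(%m) : (f32) -> ()
//   } : (f32) -> (f32)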

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in the gpu.launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
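
// Illustrative example (not from the original source) of a clustered
// reduction that passes the checks above: the cluster size (4) and the
// implicit stride (1) are both powers of two.
//
//   %1 = gpu.subgroup_reduce add %0 cluster(size = 4) : (f32) -> (f32)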

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  assert(indices.size() == 3 && "space for three indices expected");
  if (parser.parseLParen() ||
      parser.parseOperandList(args, OpAsmParser::Delimiter::None,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
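//
// Illustrative example (not from the original source) of a minimal launch in
// this syntax, without async dependencies, clusters, or attributions; all SSA
// names are made up.
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c1, %sz = %c1) {
//     gpu.terminator
//   }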
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size. In the region argument
  // list, these are the last 6 arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of Workgroup
  // Attributions, and a variadic number of Private Attributions. The number of
  // Workgroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
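
// Illustrative example (not from the original source) of the rewrite done by
// FoldLaunchArguments: given a launch whose z-grid size is a known constant 1,
// e.g.
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %n, %gy = %n, %gz = %c1) ...
//
// every use of %bz in the body is replaced by a single
// `arith.constant 0 : index` created at the start of the entry block.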

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
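
// Illustrative example (not from the original source) of the operand-list
// syntax handled by the two helpers above; the symbol, values, and types are
// made up.
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1)
//       args(%arg0 : f32, %arg1 : memref<?xf32>)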

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
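
// Illustrative example (not from the original source): this builder merely
// materializes the offset and width as arith.constant ops, so building with
// offset 1 and width 32 yields IR equivalent to:
//
//   %offset = arith.constant 1 : i32
//   %width = arith.constant 32 : i32
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32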

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that directly follows another gpu.barrier; the threads
/// are already synchronized!
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
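
// Illustrative example (not from the original source): the pattern above
// rewrites
//
//   gpu.barrier
//   gpu.barrier
//
// into a single gpu.barrier, since the second barrier synchronizes threads
// that the first one already synchronized.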

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
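//
// Illustrative example (not from the original source) of a function in this
// syntax; the names, types, and attribution are made up.
//
//   gpu.func @kernel(%arg0 : memref<4xf32>)
//       workgroup(%tmp : memref<4xf32, #gpu.address_space<workgroup>>)
//       kernel {
//     gpu.return
//   }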
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
1551 
1552 void GPUFuncOp::print(OpAsmPrinter &p) {
1553  p << ' ';
1554  p.printSymbolName(getName());
1555 
1556  FunctionType type = getFunctionType();
1557  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1558  /*isVariadic=*/false,
1559  type.getResults());
1560 
1561  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1562  getWorkgroupAttribAttrs().value_or(nullptr));
1563  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1564  getPrivateAttribAttrs().value_or(nullptr));
1565  if (isKernel())
1566  p << ' ' << getKernelKeyword();
1567 
1568  function_interface_impl::printFunctionAttributes(
1569  p, *this,
1570  {getNumWorkgroupAttributionsAttrName(),
1571  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1572  getArgAttrsAttrName(), getResAttrsAttrName(),
1573  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1574  p << ' ';
1575  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1576 }
1577 
1578 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1579  StringAttr attrName) {
1580  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1581  if (!allAttrs || index >= allAttrs.size())
1582  return DictionaryAttr();
1583  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1584 }
1585 
1586 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1587  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1588 }
1589 
1590 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1591  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1592 }
1593 
1594 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1595  DictionaryAttr value, StringAttr attrName) {
1596  MLIRContext *ctx = op.getContext();
1597  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1598  SmallVector<Attribute> elements;
1599  if (allAttrs)
1600  elements.append(allAttrs.begin(), allAttrs.end());
1601  while (elements.size() <= index)
1602  elements.push_back(DictionaryAttr::get(ctx));
1603  if (!value)
1604  elements[index] = DictionaryAttr::get(ctx);
1605  else
1606  elements[index] = value;
1607  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1608  op->setAttr(attrName, newValue);
1609 }
1610 
1611 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1612  DictionaryAttr value) {
1613  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1614 }
1615 
1616 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1617  DictionaryAttr value) {
1618  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1619 }
1620 
1621 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1622  StringAttr name, StringAttr attrsName) {
1623  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1624  if (!dict)
1625  return Attribute();
1626  return dict.get(name);
1627 }
1628 
1629 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1630  StringAttr name) {
1631  assert(index < getNumWorkgroupAttributions() &&
1632  "index must map to a workgroup attribution");
1633  return getAttributionAttr(*this, index, name,
1634  getWorkgroupAttribAttrsAttrName());
1635 }
1636 
1637 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1638  StringAttr name) {
1639  assert(index < getNumPrivateAttributions() &&
1640  "index must map to a private attribution");
1641  return getAttributionAttr(*this, index, name,
1642  getPrivateAttribAttrsAttrName());
1643 }
1644 
1645 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1646  Attribute value, StringAttr attrsName) {
1647  MLIRContext *ctx = op.getContext();
1648  SmallVector<NamedAttribute> elems;
1649  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1650  if (oldDict)
1651  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1652 
1653  bool found = false;
1654  bool mustSort = true;
1655  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1656  if (elems[i].getName() == name) {
1657  found = true;
1658  if (!value) {
1659  std::swap(elems[i], elems[elems.size() - 1]);
1660  elems.pop_back();
1661  } else {
1662  mustSort = false;
1663  elems[i] = NamedAttribute(elems[i].getName(), value);
1664  }
1665  break;
1666  }
1667  }
1668  if (!found) {
1669  if (!value)
1670  return;
1671  elems.emplace_back(name, value);
1672  }
1673  if (mustSort) {
1674  DictionaryAttr::sortInPlace(elems);
1675  }
1676  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1677  setAttributionAttrs(op, index, newDict, attrsName);
1678 }
1679 
1680 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1681  Attribute value) {
1682  assert(index < getNumWorkgroupAttributions() &&
1683  "index must map to a workgroup attribution");
1684  setAttributionAttr(*this, index, name, value,
1685  getWorkgroupAttribAttrsAttrName());
1686 }
1687 
1688 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1689  Attribute value) {
1690  assert(index < getNumPrivateAttributions() &&
1691  "index must map to a private attribution");
1692  setAttributionAttr(*this, index, name, value,
1693  getPrivateAttribAttrsAttrName());
1694 }
1695 
1696 LogicalResult GPUFuncOp::verifyType() {
1697  if (isKernel() && getFunctionType().getNumResults() != 0)
1698  return emitOpError() << "expected void return type for kernel function";
1699 
1700  return success();
1701 }
1702 
1703 /// Verifies the body of the function.
1704 LogicalResult GPUFuncOp::verifyBody() {
1705  if (empty())
1706  return emitOpError() << "expected body with at least one block";
1707  unsigned numFuncArguments = getNumArguments();
1708  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1709  unsigned numBlockArguments = front().getNumArguments();
1710  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1711  return emitOpError() << "expected at least "
1712  << numFuncArguments + numWorkgroupAttributions
1713  << " arguments to body region";
1714 
1715  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1716  for (unsigned i = 0; i < numFuncArguments; ++i) {
1717  Type blockArgType = front().getArgument(i).getType();
1718  if (funcArgTypes[i] != blockArgType)
1719  return emitOpError() << "expected body region argument #" << i
1720  << " to be of type " << funcArgTypes[i] << ", got "
1721  << blockArgType;
1722  }
1723 
1724  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1725  GPUDialect::getWorkgroupAddressSpace())) ||
1726  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1727  GPUDialect::getPrivateAddressSpace())))
1728  return failure();
1729 
1730  return success();
1731 }
1732 
1733 //===----------------------------------------------------------------------===//
1734 // ReturnOp
1735 //===----------------------------------------------------------------------===//
1736 
1737 LogicalResult gpu::ReturnOp::verify() {
1738  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1739 
1740  FunctionType funType = function.getFunctionType();
1741 
1742  if (funType.getNumResults() != getOperands().size())
1743  return emitOpError()
1744  .append("expected ", funType.getNumResults(), " result operands")
1745  .attachNote(function.getLoc())
1746  .append("return type declared here");
1747 
1748  for (const auto &pair : llvm::enumerate(
1749  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1750  auto [type, operand] = pair.value();
1751  if (type != operand.getType())
1752  return emitOpError() << "unexpected type `" << operand.getType()
1753  << "' for operand #" << pair.index();
1754  }
1755  return success();
1756 }
1757 
1758 //===----------------------------------------------------------------------===//
1759 // GPUModuleOp
1760 //===----------------------------------------------------------------------===//
1761 
1762 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1763  StringRef name, ArrayAttr targets,
1764  Attribute offloadingHandler) {
1765  result.addRegion()->emplaceBlock();
1766  Properties &props = result.getOrAddProperties<Properties>();
1767  if (targets)
1768  props.targets = targets;
1769  props.setSymName(builder.getStringAttr(name));
1770  props.offloadingHandler = offloadingHandler;
1771 }
1772 
1773 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1774  StringRef name, ArrayRef<Attribute> targets,
1775  Attribute offloadingHandler) {
1776  build(builder, result, name,
1777  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1778  offloadingHandler);
1779 }
1780 
1781 bool GPUModuleOp::hasTarget(Attribute target) {
1782  if (ArrayAttr targets = getTargetsAttr())
1783  return llvm::count(targets.getValue(), target);
1784  return false;
1785 }
1786 
1787 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1788  ArrayAttr &targetsAttr = getProperties().targets;
1789  SmallVector<Attribute> targetsVector(targets);
1790  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1791 }
1792 
1793 LogicalResult GPUModuleOp::verify() {
1794  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1795 
1796  if (!targets)
1797  return success();
1798 
1799  for (auto target : targets) {
1800  if (auto verifyTargetAttr =
1801  llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1802  if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1803  return failure();
1804  }
1805  }
1806  return success();
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // GPUBinaryOp
1811 //===----------------------------------------------------------------------===//
1812 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1813  Attribute offloadingHandler, ArrayAttr objects) {
1814  auto &properties = result.getOrAddProperties<Properties>();
1815  result.attributes.push_back(builder.getNamedAttr(
1816  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1817  properties.objects = objects;
1818  if (offloadingHandler)
1819  properties.offloadingHandler = offloadingHandler;
1820  else
1821  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1822 }
1823 
1824 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1825  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1826  build(builder, result, name, offloadingHandler,
1827  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1828 }
1829 
1830 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1831  Attribute &offloadingHandler) {
1832  if (succeeded(parser.parseOptionalLess())) {
1833  if (parser.parseAttribute(offloadingHandler))
1834  return failure();
1835  if (parser.parseGreater())
1836  return failure();
1837  }
1838  if (!offloadingHandler)
1839  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1840  return success();
1841 }
1842 
1843 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1844  Attribute offloadingHandler) {
1845  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1846  printer << '<' << offloadingHandler << '>';
1847 }
1848 
1849 //===----------------------------------------------------------------------===//
1850 // GPUMemcpyOp
1851 //===----------------------------------------------------------------------===//
1852 
1853 LogicalResult MemcpyOp::verify() {
1854  auto srcType = getSrc().getType();
1855  auto dstType = getDst().getType();
1856 
1857  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1858  return emitOpError("arguments have incompatible element type");
1859 
1860  if (failed(verifyCompatibleShape(srcType, dstType)))
1861  return emitOpError("arguments have incompatible shape");
1862 
1863  return success();
1864 }
1865 
1866 namespace {
1867 
1868 /// Erases a common case of copy ops where a destination value is used only by
1869 /// the copy op, alloc and dealloc ops.
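///
/// A sketch of the erasable case (names and types illustrative):
///
///   %dst = gpu.alloc() : memref<16xf32>
///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
///   gpu.dealloc %dst : memref<16xf32>
///
/// %dst is only written by the copy and then freed, so the copy is dead.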
1870 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1871  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1872 
1873  LogicalResult matchAndRewrite(MemcpyOp op,
1874  PatternRewriter &rewriter) const override {
1875  Value dest = op.getDst();
1876  Operation *destDefOp = dest.getDefiningOp();
1877  // `dest` must be defined by an op having Allocate memory effect in order to
1878  // perform the folding.
1879  if (!destDefOp ||
1880  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1881  return failure();
1882  // We can erase `op` iff `dest` has no other use apart from its
1883  // use by `op` and dealloc ops.
1884  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1885  return user != op &&
1886  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1887  }))
1888  return failure();
1889  // We can perform the folding if and only if op has a single async
1890  // dependency and produces an async token as result, or if it does not have
1891  // any async dependency and does not produce any async token result.
1892  if (op.getAsyncDependencies().size() > 1 ||
1893  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1894  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1895  return failure();
1896  rewriter.replaceOp(op, op.getAsyncDependencies());
1897  return success();
1898  }
1899 };
1900 
1901 } // end anonymous namespace
1902 
1903 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1904  MLIRContext *context) {
1905  results.add<EraseTrivialCopyOp>(context);
1906 }
1907 
1908 //===----------------------------------------------------------------------===//
1909 // GPU_SubgroupMmaLoadMatrixOp
1910 //===----------------------------------------------------------------------===//
1911 
1912 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1913  auto srcType = getSrcMemref().getType();
1914  auto resType = getRes().getType();
1915  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1916  auto operand = resMatrixType.getOperand();
1917  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1918 
1919  if (!srcMemrefType.isLastDimUnitStride())
1920  return emitError(
1921  "expected source memref most minor dimension to have unit stride");
1922 
1923  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1924  return emitError("only AOp, BOp and COp can be loaded");
1925 
1926  return success();
1927 }
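// A load this verifier accepts, as a sketch (shapes and leadDimension are
// illustrative):
//
//   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">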
1928 
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaStoreMatrixOp
1931 //===----------------------------------------------------------------------===//
1932 
1933 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1934  auto srcType = getSrc().getType();
1935  auto dstType = getDstMemref().getType();
1936  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1937  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1938 
1939  if (!dstMemrefType.isLastDimUnitStride())
1940  return emitError(
1941  "expected destination memref most minor dimension to have unit stride");
1942 
1943  if (srcMatrixType.getOperand() != "COp")
1944  return emitError(
1945  "expected the operand matrix being stored to have 'COp' operand type");
1946 
1947  return success();
1948 }
1949 
1950 //===----------------------------------------------------------------------===//
1951 // GPU_SubgroupMmaComputeOp
1952 //===----------------------------------------------------------------------===//
1953 
1954 LogicalResult SubgroupMmaComputeOp::verify() {
1955  enum OperandMap { A, B, C };
1956  SmallVector<MMAMatrixType, 3> opTypes;
1957  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1958  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1959  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1960 
1961  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1962  opTypes[C].getOperand() != "COp")
1963  return emitError("operands must be in the order AOp, BOp, COp");
1964 
1965  ArrayRef<int64_t> aShape, bShape, cShape;
1966  aShape = opTypes[A].getShape();
1967  bShape = opTypes[B].getShape();
1968  cShape = opTypes[C].getShape();
1969 
1970  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1971  bShape[1] != cShape[1])
1972  return emitError("operand shapes do not satisfy matmul constraints");
1973 
1974  return success();
1975 }
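// The shape check encodes C += A * B: an AOp of shape m x k requires a BOp of
// shape k x n and a COp of shape m x n. For example, 16x16 * 16x16 += 16x16
// passes, while an AOp of 16x8 against a BOp of 16x16 violates
// aShape[1] == bShape[0].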
1976 
1977 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1978  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1979  return memref::foldMemRefCast(*this);
1980 }
1981 
1982 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1983  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1984  return memref::foldMemRefCast(*this);
1985 }
1986 
1987 //===----------------------------------------------------------------------===//
1988 // GPU_WaitOp
1989 //===----------------------------------------------------------------------===//
1990 
1991 namespace {
1992 
1993 /// Remove gpu.wait op use of gpu.wait op def without async dependencies.
1994 /// %t = gpu.wait async [] // No async dependencies.
1995 /// ... gpu.wait ... [%t, ...] // %t can be removed.
1996 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1997 public:
1998  using OpRewritePattern::OpRewritePattern;
1999 
2000  LogicalResult matchAndRewrite(WaitOp op,
2001  PatternRewriter &rewriter) const final {
2002  auto predicate = [](Value value) {
2003  auto waitOp = value.getDefiningOp<WaitOp>();
2004  return waitOp && waitOp->getNumOperands() == 0;
2005  };
2006  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2007  return failure();
2008  SmallVector<Value> validOperands;
2009  for (Value operand : op->getOperands()) {
2010  if (predicate(operand))
2011  continue;
2012  validOperands.push_back(operand);
2013  }
2014  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2015  return success();
2016  }
2017 };
2018 
2019 /// Simplify trivial gpu.wait ops for the following patterns.
2020 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2021 /// dependencies).
2022 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2023 /// %t0.
2024 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2025 /// dependencies nor return any token.
2026 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2027 public:
2028  using OpRewritePattern::OpRewritePattern;
2029 
2030  LogicalResult matchAndRewrite(WaitOp op,
2031  PatternRewriter &rewriter) const final {
2032  // Erase gpu.wait ops that neither have any async dependencies nor return
2033  // any async token.
2034  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2035  rewriter.eraseOp(op);
2036  return success();
2037  }
2038  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2039  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2040  op.getAsyncToken()) {
2041  rewriter.replaceOp(op, op.getAsyncDependencies());
2042  return success();
2043  }
2044  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2045  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2046  rewriter.eraseOp(op);
2047  return success();
2048  }
2049  return failure();
2050  }
2051 };
2052 
2053 } // end anonymous namespace
2054 
2055 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2056  MLIRContext *context) {
2057  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2058 }
2059 
2060 //===----------------------------------------------------------------------===//
2061 // GPU_AllocOp
2062 //===----------------------------------------------------------------------===//
2063 
2064 LogicalResult AllocOp::verify() {
2065  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2066 
2067  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2068  return emitOpError("dimension operand count does not equal memref "
2069  "dynamic dimension count");
2070 
2071  unsigned numSymbols = 0;
2072  if (!memRefType.getLayout().isIdentity())
2073  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2074  if (getSymbolOperands().size() != numSymbols) {
2075  return emitOpError(
2076  "symbol operand count does not equal memref symbol count");
2077  }
2078 
2079  return success();
2080 }
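// For example (a sketch), gpu.alloc(%d0, %d1) with result type
// memref<?x4x?xf32> passes: it supplies exactly one size operand per
// dynamic ('?') dimension and no symbol operands for the identity layout.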
2081 
2082 namespace {
2083 
2084 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2085 /// `memref::AllocOp`.
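///
/// A sketch of the fold (names illustrative):
///
///   %m = gpu.alloc(%size) : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// Here all uses of %d are replaced with %size.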
2086 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2087  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2088 
2089  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2090  PatternRewriter &rewriter) const override {
2091  std::optional<int64_t> index = dimOp.getConstantIndex();
2092  if (!index)
2093  return failure();
2094 
2095  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2096  if (!memrefType || index.value() >= memrefType.getRank() ||
2097  !memrefType.isDynamicDim(index.value()))
2098  return failure();
2099 
2100  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2101  if (!alloc)
2102  return failure();
2103 
2104  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2105  memrefType.getDynamicDimIndex(index.value()));
2106  rewriter.replaceOp(dimOp, substituteOp);
2107  return success();
2108  }
2109 };
2110 
2111 } // namespace
2112 
2113 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2114  MLIRContext *context) {
2115  results.add<SimplifyDimOfAllocOp>(context);
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // GPU object attribute
2120 //===----------------------------------------------------------------------===//
2121 
2122 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2123  Attribute target, CompilationTarget format,
2124  StringAttr object, DictionaryAttr properties,
2125  KernelTableAttr kernels) {
2126  if (!target)
2127  return emitError() << "the target attribute cannot be null";
2128  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2129  return success();
2130  return emitError() << "the target attribute must implement or promise the "
2131  "`gpu::TargetAttrInterface`";
2132 }
2133 
2134 namespace {
2135 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2136  StringAttr &object) {
2137  std::optional<CompilationTarget> formatResult;
2138  StringRef enumKeyword;
2139  auto loc = odsParser.getCurrentLocation();
2140  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2141  formatResult = CompilationTarget::Fatbin;
2142  if (!formatResult &&
2143  (formatResult =
2144  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2145  odsParser.parseEqual())
2146  return odsParser.emitError(loc, "expected an equal sign");
2147  if (!formatResult)
2148  return odsParser.emitError(loc, "expected keyword for GPU object format");
2149  FailureOr<StringAttr> objectResult =
2150  FieldParser<StringAttr>::parse(odsParser);
2151  if (failed(objectResult))
2152  return odsParser.emitError(odsParser.getCurrentLocation(),
2153  "failed to parse GPU_ObjectAttr parameter "
2154  "'object' which is to be a `StringAttr`");
2155  format = *formatResult;
2156  object = *objectResult;
2157  return success();
2158 }
2159 
2160 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2161  StringAttr object) {
2162  if (format != CompilationTarget::Fatbin)
2163  odsParser << stringifyEnum(format) << " = ";
2164  odsParser << object;
2165 }
2166 } // namespace
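// As printed, this yields forms such as `#gpu.object<#nvvm.target, "BLOB">`
// for the default fatbin format, or `#gpu.object<#nvvm.target,
// assembly = "...">` when the format keyword is explicit (a sketch; the
// target attribute shown is illustrative).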
2167 
2168 //===----------------------------------------------------------------------===//
2169 // GPU select object attribute
2170 //===----------------------------------------------------------------------===//
2171 
2172 LogicalResult
2173 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2174  Attribute target) {
2175  // Check `target`: it can be null, an integer attr, or a GPU Target attribute.
2176  if (target) {
2177  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2178  if (intAttr.getInt() < 0) {
2179  return emitError() << "the object index must be non-negative";
2180  }
2181  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2182  return emitError()
2183  << "the target attribute must be a GPU Target attribute";
2184  }
2185  }
2186  return success();
2187 }
2188 
2189 //===----------------------------------------------------------------------===//
2190 // DynamicSharedMemoryOp
2191 //===----------------------------------------------------------------------===//
2192 
2193 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2194  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2195  return emitOpError() << "must be inside an op with symbol table";
2196 
2197  MemRefType memrefType = getResultMemref().getType();
2198  // Check address space
2199  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2200  return emitOpError() << "address space must be "
2201  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2202  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2203  }
2204  if (memrefType.hasStaticShape()) {
2205  return emitOpError() << "result memref type must be memref<?xi8, "
2206  "#gpu.address_space<workgroup>>";
2207  }
2208  return success();
2209 }
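// A well-formed instance accepted by this verifier, as a sketch:
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>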
2210 
2211 //===----------------------------------------------------------------------===//
2212 // GPU WarpExecuteOnLane0Op
2213 //===----------------------------------------------------------------------===//
2214 
2215 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2216  p << "(" << getLaneid() << ")";
2217 
2218  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2219  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2220  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2221 
2222  if (!getArgs().empty())
2223  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2224  if (!getResults().empty())
2225  p << " -> (" << getResults().getTypes() << ')';
2226  p << " ";
2227  p.printRegion(getRegion(),
2228  /*printEntryBlockArgs=*/true,
2229  /*printBlockTerminators=*/!getResults().empty());
2230  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2231 }
2232 
2233 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2234  OperationState &result) {
2235  // Create the region.
2236  result.regions.reserve(1);
2237  Region *warpRegion = result.addRegion();
2238 
2239  auto &builder = parser.getBuilder();
2240  OpAsmParser::UnresolvedOperand laneId;
2241 
2242  // Parse predicate operand.
2243  if (parser.parseLParen() ||
2244  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2245  parser.parseRParen())
2246  return failure();
2247 
2248  int64_t warpSize;
2249  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2250  parser.parseRSquare())
2251  return failure();
2252  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2253  builder.getContext())),
2254  builder.getI64IntegerAttr(warpSize));
2255 
2256  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2257  return failure();
2258 
2259  llvm::SMLoc inputsOperandsLoc;
2260  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2261  SmallVector<Type> inputTypes;
2262  if (succeeded(parser.parseOptionalKeyword("args"))) {
2263  if (parser.parseLParen())
2264  return failure();
2265 
2266  inputsOperandsLoc = parser.getCurrentLocation();
2267  if (parser.parseOperandList(inputsOperands) ||
2268  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2269  return failure();
2270  }
2271  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2272  result.operands))
2273  return failure();
2274 
2275  // Parse optional results type list.
2276  if (parser.parseOptionalArrowTypeList(result.types))
2277  return failure();
2278  // Parse the region.
2279  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2280  /*argTypes=*/{}))
2281  return failure();
2282  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2283 
2284  // Parse the optional attribute list.
2285  if (parser.parseOptionalAttrDict(result.attributes))
2286  return failure();
2287  return success();
2288 }
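// A form this parser accepts, as a sketch (types chosen so every expanded
// dimension divides evenly by the warp size of 32):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg : vector<128xf32>):
//     ...
//     gpu.yield %w : vector<32xf32>
//   }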
2289 
2290 void WarpExecuteOnLane0Op::getSuccessorRegions(
2291  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2292  if (!point.isParent()) {
2293  regions.push_back(RegionSuccessor(getResults()));
2294  return;
2295  }
2296 
2297  // The warp region is always executed
2298  regions.push_back(RegionSuccessor(&getWarpRegion()));
2299 }
2300 
2301 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2302  TypeRange resultTypes, Value laneId,
2303  int64_t warpSize) {
2304  build(builder, result, resultTypes, laneId, warpSize,
2305  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2306 }
2307 
2308 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2309  TypeRange resultTypes, Value laneId,
2310  int64_t warpSize, ValueRange args,
2311  TypeRange blockArgTypes) {
2312  result.addOperands(laneId);
2313  result.addAttribute(getAttributeNames()[0],
2314  builder.getI64IntegerAttr(warpSize));
2315  result.addTypes(resultTypes);
2316  result.addOperands(args);
2317  assert(args.size() == blockArgTypes.size());
2318  OpBuilder::InsertionGuard guard(builder);
2319  Region *warpRegion = result.addRegion();
2320  Block *block = builder.createBlock(warpRegion);
2321  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2322  block->addArgument(type, arg.getLoc());
2323 }
2324 
2325 /// Helper to check that the distributed vector type is consistent with the
2326 /// expanded type and the distributed size.
2327 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2328  int64_t warpSize, Operation *op) {
2329  // If the types match, there is no distribution.
2330  if (expanded == distributed)
2331  return success();
2332  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2333  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2334  if (!expandedVecType || !distributedVecType)
2335  return op->emitOpError("expected vector type for distributed operands.");
2336  if (expandedVecType.getRank() != distributedVecType.getRank() ||
2337  expandedVecType.getElementType() != distributedVecType.getElementType())
2338  return op->emitOpError(
2339  "expected distributed vectors to have same rank and element type.");
2340 
2341  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2342  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2343  int64_t eDim = expandedVecType.getDimSize(i);
2344  int64_t dDim = distributedVecType.getDimSize(i);
2345  if (eDim == dDim)
2346  continue;
2347  if (eDim % dDim != 0)
2348  return op->emitOpError()
2349  << "expected expanded vector dimension #" << i << " (" << eDim
2350  << ") to be a multiple of the distributed vector dimension ("
2351  << dDim << ")";
2352  scales[i] = eDim / dDim;
2353  }
2354  if (std::accumulate(scales.begin(), scales.end(), 1,
2355  std::multiplies<int64_t>()) != warpSize)
2356  return op->emitOpError()
2357  << "incompatible distribution dimensions from " << expandedVecType
2358  << " to " << distributedVecType << " with warp size = " << warpSize;
2359 
2360  return success();
2361 }
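// For example, with warpSize = 32 an expanded vector<64x4xf32> may distribute
// to vector<2x4xf32> (scales 32 * 1 == 32), whereas vector<2x2xf32> is
// rejected because the scales multiply to 32 * 2 == 64 != 32.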
2362 
2363 LogicalResult WarpExecuteOnLane0Op::verify() {
2364  if (getArgs().size() != getWarpRegion().getNumArguments())
2365  return emitOpError(
2366  "expected same number of op arguments and block arguments.");
2367  auto yield =
2368  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2369  if (yield.getNumOperands() != getNumResults())
2370  return emitOpError(
2371  "expected same number of yield operands and return values.");
2372  int64_t warpSize = getWarpSize();
2373  for (auto [regionArg, arg] :
2374  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2375  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2376  warpSize, getOperation())))
2377  return failure();
2378  }
2379  for (auto [yieldOperand, result] :
2380  llvm::zip_equal(yield.getOperands(), getResults())) {
2381  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2382  warpSize, getOperation())))
2383  return failure();
2384  }
2385  return success();
2386 }
2387 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2388  return succeeded(
2389  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2390 }
2391 
2392 //===----------------------------------------------------------------------===//
2393 // GPU KernelMetadataAttr
2394 //===----------------------------------------------------------------------===//
2395 
2396 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2397  DictionaryAttr metadata) {
2398  assert(kernel && "invalid kernel");
2399  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2400  kernel.getAllArgAttrs(), metadata);
2401 }
2402 
2403 KernelMetadataAttr
2404 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2405  FunctionOpInterface kernel,
2406  DictionaryAttr metadata) {
2407  assert(kernel && "invalid kernel");
2408  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2409  kernel.getAllArgAttrs(), metadata);
2410 }
2411 
2412 KernelMetadataAttr
2413 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2414  if (attrs.empty())
2415  return *this;
2416  NamedAttrList attrList;
2417  if (DictionaryAttr dict = getMetadata())
2418  attrList.append(dict);
2419  attrList.append(attrs);
2420  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2421  attrList.getDictionary(getContext()));
2422 }
2423 
2424 LogicalResult
2425 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2426  StringAttr name, Type functionType,
2427  ArrayAttr argAttrs, DictionaryAttr metadata) {
2428  if (name.empty())
2429  return emitError() << "the kernel name can't be empty";
2430  if (argAttrs) {
2431  if (llvm::any_of(argAttrs, [](Attribute attr) {
2432  return !llvm::isa<DictionaryAttr>(attr);
2433  }))
2434  return emitError()
2435  << "all attributes in the array must be a dictionary attribute";
2436  }
2437  return success();
2438 }
2439 
2440 //===----------------------------------------------------------------------===//
2441 // GPU KernelTableAttr
2442 //===----------------------------------------------------------------------===//
2443 
2444 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2445  ArrayRef<KernelMetadataAttr> kernels,
2446  bool isSorted) {
2447  // Note that `is_sorted` is invoked only once, even with assertions ON.
2448  assert((!isSorted || llvm::is_sorted(kernels)) &&
2449  "expected a sorted kernel array");
2450  // Immediately return the attribute if the array is sorted.
2451  if (isSorted || llvm::is_sorted(kernels))
2452  return Base::get(context, kernels);
2453  // Sort the array.
2454  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2455  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2456  return Base::get(context, kernelsTmp);
2457 }
2458 
2459 KernelTableAttr KernelTableAttr::getChecked(
2460  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2461  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2462  // Note that `is_sorted` is invoked only once, even with assertions ON.
2463  assert((!isSorted || llvm::is_sorted(kernels)) &&
2464  "expected a sorted kernel array");
2465  // Immediately return the attribute if the array is sorted.
2466  if (isSorted || llvm::is_sorted(kernels))
2467  return Base::getChecked(emitError, context, kernels);
2468  // Sort the array.
2469  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2470  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2471  return Base::getChecked(emitError, context, kernelsTmp);
2472 }
2473 
2474 LogicalResult
2475 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2476  ArrayRef<KernelMetadataAttr> kernels) {
2477  if (kernels.size() < 2)
2478  return success();
2479  // Check that the kernels are uniquely named.
2480  if (std::adjacent_find(kernels.begin(), kernels.end(),
2481  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2482  return l.getName() == r.getName();
2483  }) != kernels.end()) {
2484  return emitError() << "expected all kernels to be uniquely named";
2485  }
2486  return success();
2487 }
2488 
2489 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2490  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2491  return found ? *iterator : KernelMetadataAttr();
2492 }
2493 
2494 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2495  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2496  return found ? *iterator : KernelMetadataAttr();
2497 }
2498 
2499 //===----------------------------------------------------------------------===//
2500 // GPU target options
2501 //===----------------------------------------------------------------------===//
2502 
2503 TargetOptions::TargetOptions(
2504  StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2505  StringRef cmdOptions, StringRef elfSection,
2506  CompilationTarget compilationTarget,
2507  function_ref<SymbolTable *()> getSymbolTableCallback,
2508  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2509  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2510  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2511  function_ref<void(StringRef)> isaCallback)
2512  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2513  cmdOptions, elfSection, compilationTarget,
2514  getSymbolTableCallback, initialLlvmIRCallback,
2515  linkedLlvmIRCallback, optimizedLlvmIRCallback,
2516  isaCallback) {}
2517 
2518 TargetOptions::TargetOptions(
2519  TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2520  StringRef cmdOptions, StringRef elfSection,
2521  CompilationTarget compilationTarget,
2522  function_ref<SymbolTable *()> getSymbolTableCallback,
2523  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2524  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2525  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2526  function_ref<void(StringRef)> isaCallback)
2527  : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2528  cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2529  compilationTarget(compilationTarget),
2530  getSymbolTableCallback(getSymbolTableCallback),
2531  initialLlvmIRCallback(initialLlvmIRCallback),
2532  linkedLlvmIRCallback(linkedLlvmIRCallback),
2533  optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2534  isaCallback(isaCallback), typeID(typeID) {}
2535 
2536 TypeID TargetOptions::getTypeID() const { return typeID; }
2537 
2538 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2539 
2540 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2541  return librariesToLink;
2542 }
2543 
2544 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2545 
2546 StringRef TargetOptions::getELFSection() const { return elfSection; }
2547 
2548 SymbolTable *TargetOptions::getSymbolTable() const {
2549  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2550 }
2551 
2552 function_ref<void(llvm::Module &)>
2553 TargetOptions::getInitialLlvmIRCallback() const {
2554  return initialLlvmIRCallback;
2555 }
2556 
2557 function_ref<void(llvm::Module &)>
2558 TargetOptions::getLinkedLlvmIRCallback() const {
2559  return linkedLlvmIRCallback;
2560 }
2561 
2562 function_ref<void(llvm::Module &)>
2563 TargetOptions::getOptimizedLlvmIRCallback() const {
2564  return optimizedLlvmIRCallback;
2565 }
2566 
2567 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2568  return isaCallback;
2569 }
2570 
2571 CompilationTarget TargetOptions::getCompilationTarget() const {
2572  return compilationTarget;
2573 }
2574 
2575 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2576  return CompilationTarget::Fatbin;
2577 }
2578 
2579 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2580 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2581  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2582  llvm::StringSaver stringSaver(options.first);
2583  StringRef opts = cmdOptions;
2584  // For a correct tokenization of the command line options, `opts` must be
2585  // unquoted; otherwise the tokenization function returns a single token
2586  // (the unquoted `cmdOptions`), which is not the desired behavior.
2587  // Remove any quotes if they are at the beginning and end of the string:
2588  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2589  opts.consume_front("\""), opts.consume_back("\"");
2590  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2591  opts.consume_front("'"), opts.consume_back("'");
2592 #ifdef _WIN32
2593  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2594  /*MarkEOLs=*/false);
2595 #else
2596  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2597  /*MarkEOLs=*/false);
2598 #endif // _WIN32
2599  return options;
2600 }
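// For instance, a cmdOptions value of "\"-O3 -g\"" is unquoted to -O3 -g and
// then tokenized into {"-O3", "-g"}; without the unquoting step the whole
// quoted payload would come back as a single token.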
2601 
2602 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2603 TargetOptions::tokenizeCmdOptions() const {
2604  return tokenizeCmdOptions(cmdOptions);
2605 }
2606 
2607 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2608 TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2609  size_t startPos = cmdOptions.find(startsWith);
2610  if (startPos == std::string::npos)
2611  return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2612 
2613  auto tokenized =
2614  tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2615  cmdOptions.resize(startPos);
2616  return tokenized;
2617 }
2618 
2619 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2620 
2621 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2622 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2623 
2624 #define GET_ATTRDEF_CLASSES
2625 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2626 
2627 #define GET_OP_CLASSES
2628 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2629 
2630 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
Definition: GPUDialect.cpp:429
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Definition: GPUDialect.cpp:592
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
Definition: GPUDialect.cpp:445
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
Definition: GPUDialect.cpp:903
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
Definition: GPUDialect.cpp:477
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
Definition: GPUDialect.cpp:570
static std::string getSparseHandleKeyword(SparseHandleKind kind)
Definition: GPUDialect.cpp:229
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
Definition: GPUDialect.cpp:319
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Definition: GPUDialect.cpp:605
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
Definition: GPUDialect.cpp:466
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
Definition: GPUDialect.cpp:516
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
Definition: GPUDialect.cpp:854
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
Definition: GPUDialect.cpp:490
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MINUI(lhs, rhs)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:88
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:100
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:90
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:95
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
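These rewriter-side entry points combine as in the following registration sketch; MyLaunchPattern and the populate function are hypothetical placeholders for any OpRewritePattern subclass, such as the gpu.launch simplification described further below:
  static void populateMyPatterns(RewritePatternSet &patterns) {
    // Patterns are constructed against the context owned by the set.
    patterns.add<MyLaunchPattern>(patterns.getContext());
  }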
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isIndex() const
Definition: Types.cpp:54
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
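These predicates are the building blocks for element-type checks such as isValidElementType above. A simplified sketch; the authoritative set of allowed types is the predicate at GPUDialect.cpp:148:
  // Simplified sketch of an MMA element-type predicate.
  static bool isLikelyValidMMAElementType(Type elementType) {
    return elementType.isF16() || elementType.isF32() ||
           elementType.isSignedInteger(8) ||
           elementType.isUnsignedInteger(8) || elementType.isInteger(32);
  }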
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk.
Definition: Visitors.h:33
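A brief sketch of the walk idiom WalkResult supports, given an Operation *op; the predicate on async dependencies is an illustrative assumption:
  // Interrupt the walk at the first gpu.launch carrying async dependencies.
  WalkResult result = op->walk([](gpu::LaunchOp launch) {
    return launch.getAsyncDependencies().empty() ? WalkResult::advance()
                                                 : WalkResult::interrupt();
  });
  bool sawAsyncLaunch = result.wasInterrupted();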
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:140
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:125
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:144
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
Definition: GPUDialect.cpp:148
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition: GPUDialect.cpp:155
StringRef getOperand() const
The general form of the operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:146
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
Definition: GPUDialect.cpp:131
unsigned getNumDims() const
Get number of dims.
Definition: GPUDialect.cpp:138
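Putting the MMAMatrixType API together, a minimal sketch of constructing and querying a 16x16 f16 "AOp" matrix type, given an MLIRContext *ctx (the shape and operand string are example values):
  gpu::MMAMatrixType aType = gpu::MMAMatrixType::getChecked(
      mlir::detail::getDefaultDiagnosticEmitFn(ctx), /*shape=*/{16, 16},
      Float16Type::get(ctx), /*operand=*/"AOp");
  // getChecked returns a null type (and emits a diagnostic) on failure.
  assert(aType && aType.getNumDims() == 2 && aType.getElementType().isF16());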
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
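As a hedged sketch of consuming these tokenizations, where targetOptions is an assumed TargetOptions instance; the returned allocator owns the token storage, so it must stay alive while the pointers are used:
  auto [allocator, tokens] = targetOptions.tokenizeCmdOptions();
  for (const char *token : tokens)
    llvm::errs() << token << "\n";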
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Definition: GPUDialect.cpp:669
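addAsyncDependency appends a !gpu.async.token operand to an op implementing the GPU async-op interface. A short sketch; chainAfterWait and its arguments are hypothetical:
  // Make 'consumer' wait on the token produced by 'waitOp'. 'consumer'
  // must implement gpu::AsyncOpInterface.
  static void chainAfterWait(gpu::WaitOp waitOp, Operation *consumer) {
    if (Value token = waitOp.getAsyncToken())
      gpu::addAsyncDependency(consumer, token);
  }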
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
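m_One pairs with matchPattern for the constant-one test used by the gpu.launch simplification below; a one-line sketch, given a Value v:
  // True iff 'v' is a constant (or splat) integer one.
  bool isOne = matchPattern(v, m_One());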
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type into an MLIR context, if it is valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
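A simplified sketch of what such a matchAndRewrite can look like, handling only the x thread ID and assuming the arith dialect is loaded for the zero constant; the in-tree pattern covers all six IDs and more cases:
  struct SimplifyUnitThreadIdX : OpRewritePattern<gpu::LaunchOp> {
    using OpRewritePattern::OpRewritePattern;
    LogicalResult matchAndRewrite(gpu::LaunchOp op,
                                  PatternRewriter &rewriter) const override {
      // Only fires when the block's x-size is the constant 1.
      if (!matchPattern(op.getBlockSizeX(), m_One()))
        return failure();
      Value tidX = op.getThreadIds().x;
      if (tidX.use_empty())
        return failure();
      // A unit-range thread ID can only be 0; rewrite its uses.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&op.getBody().front());
      Value zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      rewriter.replaceAllUsesWith(tidX, zero);
      return success();
    }
  };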
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
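OperationState drives the generic create path shown above (Builders.cpp:453). A minimal sketch, given an OpBuilder builder and a Location loc; gpu.barrier is chosen only because it takes no operands or results:
  // Build a gpu.barrier through the generic OperationState API.
  OperationState state(loc, gpu::BarrierOp::getOperationName());
  Operation *barrier = builder.create(state);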
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.
Definition: GPUDialect.h:39
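For example, given a gpu::LaunchOp launchOp (a placeholder), the block IDs arrive as one such triple:
  gpu::KernelDim3 blockIds = launchOp.getBlockIds();
  Value x = blockIds.x, y = blockIds.y, z = blockIds.z;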