//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//
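
// Note: in the MappingId enum the three grid dimensions (DimX, DimY, DimZ)
// precede the linear dimensions (LinearDim0, LinearDim1, ...), so a mapping
// id at or past LinearDim0 denotes a linear mapping and its relative index
// is rebased to LinearDim0. The accessors below all follow this scheme.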

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//
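
// In the textual IR an MMA matrix type appears as, for example,
//   !gpu.mma_matrix<16x16xf16, "AOp">
// where the operand string must be one of "AOp", "BOp" or "COp" (see the
// verifier below).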
121 
123  StringRef operand) {
124  return Base::get(elementType.getContext(), shape, elementType, operand);
125 }
126 
129  ArrayRef<int64_t> shape, Type elementType,
130  StringRef operand) {
131  return Base::getChecked(emitError, elementType.getContext(), shape,
132  elementType, operand);
133 }
134 
135 unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
136 
138  return getImpl()->getShape();
139 }
140 
141 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
142 
143 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
144 
146  return elementType.isF16() || elementType.isF32() ||
147  elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
148  elementType.isInteger(32);
149 }
150 

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}
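
/// Parses one of the dialect's custom types, e.g. `!gpu.async.token`,
/// `!gpu.mma_matrix<16x16xf16, "AOp">`, or one of the sparse handle types
/// such as `!gpu.sparse.spmat_handle`.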
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

// TODO: Print the refined type here. Note that this must stay in sync with
// the parser above.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
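
/// Verifies GPU dialect attributes attached to arbitrary operations. In
/// particular, when the `gpu.container_module` unit attribute is set on a
/// builtin module, every `gpu.launch_func` nested in a function of that
/// module must reference a well-formed kernel module, binary, and kernel
/// function.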
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
///    (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
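/// For example, in `%t = gpu.wait async [%t0, %t1]` this parses
/// `async [%t0, %t1]` and sets `asyncTokenType` to `!gpu.async.token`.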
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
///    (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
///    memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                           (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
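
/// Checks that the reduction kind is compatible with the element type:
/// floating-point min/max variants require a float type, and the integer
/// min/max and bitwise kinds require an integer type.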
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
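
/// Returns true if the given group op can be marked uniform: currently, only
/// ops placed directly in the entry block of a `gpu.launch` body are known to
/// be executed by all work items.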
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//
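
/// Prepends `token` to the op's operand list as a new async dependency and,
/// for ops with attribute-sized operand segments, bumps the leading segment
/// size so the segment bookkeeping stays consistent.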
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}
763 
764 std::optional<KernelDim3> LaunchOp::getClusterIds() {
765  assert(!getBody().empty() && "LaunchOp body must not be empty.");
766  if (!hasClusterSize())
767  return std::nullopt;
768  auto args = getBody().getArguments();
769  return KernelDim3{args[12], args[13], args[14]};
770 }
771 
772 std::optional<KernelDim3> LaunchOp::getClusterSize() {
773  assert(!getBody().empty() && "LaunchOp body must not be empty.");
774  if (!hasClusterSize())
775  return std::nullopt;
776  auto args = getBody().getArguments();
777  return KernelDim3{args[15], args[16], args[17]};
778 }
779 
780 KernelDim3 LaunchOp::getGridSizeOperandValues() {
781  auto operands = getOperands().drop_front(getAsyncDependencies().size());
782  return KernelDim3{operands[0], operands[1], operands[2]};
783 }
784 
785 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
786  auto operands = getOperands().drop_front(getAsyncDependencies().size());
787  return KernelDim3{operands[3], operands[4], operands[5]};
788 }
789 
790 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
791  auto operands = getOperands().drop_front(getAsyncDependencies().size());
792  if (!hasClusterSize())
793  return std::nullopt;
794  return KernelDim3{operands[6], operands[7], operands[8]};
795 }
796 
797 LogicalResult LaunchOp::verify() {
798  if (!(hasClusterSize()) &&
799  (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
800  return emitOpError() << "cluster size must be all present";
801  return success();
802 }

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size. In the region argument
  // list, these are the last 6 arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
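/// For example, with a unit-sized grid dimension,
///   gpu.launch blocks(%bx, %by, %bz) in (%sx = %c1, %sy = %sy0, %sz = %sz0)
/// every use of %bx inside the body folds to a constant zero.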
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
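
/// Builds a shuffle op for the common case where the offset and width are
/// compile-time constants, materializing them as `arith.constant` i32
/// operands.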
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
///    memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                           (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
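/// For example:
///   gpu.func @kernel(%arg0: f32) workgroup(%buf : memref<32xf32, 3>) kernel {
///     gpu.return
///   }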
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
1549 
1550 void GPUFuncOp::print(OpAsmPrinter &p) {
1551  p << ' ';
1552  p.printSymbolName(getName());
1553 
1554  FunctionType type = getFunctionType();
1555  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1556  /*isVariadic=*/false,
1557  type.getResults());
1558 
1559  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1560  getWorkgroupAttribAttrs().value_or(nullptr));
1561  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1562  getPrivateAttribAttrs().value_or(nullptr));
1563  if (isKernel())
1564  p << ' ' << getKernelKeyword();
1565 
1566  function_interface_impl::printFunctionAttributes(
1567  p, *this,
1568  {getNumWorkgroupAttributionsAttrName(),
1569  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1570  getArgAttrsAttrName(), getResAttrsAttrName(),
1571  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1572  p << ' ';
1573  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1574 }
1575 
1576 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1577  StringAttr attrName) {
1578  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1579  if (!allAttrs || index >= allAttrs.size())
1580  return DictionaryAttr();
1581  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1582 }
1583 
1584 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1585  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1586 }
1587 
1588 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1589  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1590 }
1591 
1592 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1593  DictionaryAttr value, StringAttr attrName) {
1594  MLIRContext *ctx = op.getContext();
1595  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1596  SmallVector<Attribute> elements;
1597  if (allAttrs)
1598  elements.append(allAttrs.begin(), allAttrs.end());
1599  while (elements.size() <= index)
1600  elements.push_back(DictionaryAttr::get(ctx));
1601  if (!value)
1602  elements[index] = DictionaryAttr::get(ctx);
1603  else
1604  elements[index] = value;
1605  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1606  op->setAttr(attrName, newValue);
1607 }
1608 
1609 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1610  DictionaryAttr value) {
1611  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1612 }
1613 
1614 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1615  DictionaryAttr value) {
1616  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1617 }
1618 
1619 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1620  StringAttr name, StringAttr attrsName) {
1621  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1622  if (!dict)
1623  return Attribute();
1624  return dict.get(name);
1625 }
1626 
1627 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1628  StringAttr name) {
1629  assert(index < getNumWorkgroupAttributions() &&
1630  "index must map to a workgroup attribution");
1631  return getAttributionAttr(*this, index, name,
1632  getWorkgroupAttribAttrsAttrName());
1633 }
1634 
1635 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1636  StringAttr name) {
1637  assert(index < getNumPrivateAttributions() &&
1638  "index must map to a private attribution");
1639  return getAttributionAttr(*this, index, name,
1640  getPrivateAttribAttrsAttrName());
1641 }
1642 
1643 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1644  Attribute value, StringAttr attrsName) {
1645  MLIRContext *ctx = op.getContext();
1646  SmallVector<NamedAttribute> elems;
1647  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1648  if (oldDict)
1649  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1650 
1651  bool found = false;
1652  bool mustSort = true;
1653  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1654  if (elems[i].getName() == name) {
1655  found = true;
1656  if (!value) {
1657  std::swap(elems[i], elems[elems.size() - 1]);
1658  elems.pop_back();
1659  } else {
1660  mustSort = false;
1661  elems[i] = NamedAttribute(elems[i].getName(), value);
1662  }
1663  break;
1664  }
1665  }
1666  if (!found) {
1667  if (!value)
1668  return;
1669  elems.emplace_back(name, value);
1670  }
1671  if (mustSort) {
1672  DictionaryAttr::sortInPlace(elems);
1673  }
1674  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1675  setAttributionAttrs(op, index, newDict, attrsName);
1676 }
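// A minimal usage sketch (hypothetical attribute name `my.unit`, builder in
// scope): setting a value updates the sorted per-attribution dictionary
// maintained above, and passing a null value erases the entry again.
//
//   funcOp.setWorkgroupAttributionAttr(0, builder.getStringAttr("my.unit"),
//                                      builder.getUnitAttr());
//   funcOp.setWorkgroupAttributionAttr(0, builder.getStringAttr("my.unit"),
//                                      /*value=*/nullptr);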
1677 
1678 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1679  Attribute value) {
1680  assert(index < getNumWorkgroupAttributions() &&
1681  "index must map to a workgroup attribution");
1682  setAttributionAttr(*this, index, name, value,
1683  getWorkgroupAttribAttrsAttrName());
1684 }
1685 
1686 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1687  Attribute value) {
1688  assert(index < getNumPrivateAttributions() &&
1689  "index must map to a private attribution");
1690  setAttributionAttr(*this, index, name, value,
1691  getPrivateAttribAttrsAttrName());
1692 }
1693 
1694 LogicalResult GPUFuncOp::verifyType() {
1695  if (isKernel() && getFunctionType().getNumResults() != 0)
1696  return emitOpError() << "expected void return type for kernel function";
1697 
1698  return success();
1699 }
1700 
1701 /// Verifies the body of the function.
1702 LogicalResult GPUFuncOp::verifyBody() {
1703  if (empty())
1704  return emitOpError() << "expected body with at least one block";
1705  unsigned numFuncArguments = getNumArguments();
1706  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1707  unsigned numBlockArguments = front().getNumArguments();
1708  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1709  return emitOpError() << "expected at least "
1710  << numFuncArguments + numWorkgroupAttributions
1711  << " arguments to body region";
1712 
1713  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1714  for (unsigned i = 0; i < numFuncArguments; ++i) {
1715  Type blockArgType = front().getArgument(i).getType();
1716  if (funcArgTypes[i] != blockArgType)
1717  return emitOpError() << "expected body region argument #" << i
1718  << " to be of type " << funcArgTypes[i] << ", got "
1719  << blockArgType;
1720  }
1721 
1722  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1723  GPUDialect::getWorkgroupAddressSpace())) ||
1724  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1725  GPUDialect::getPrivateAddressSpace())))
1726  return failure();
1727 
1728  return success();
1729 }
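// As a concrete count for the check above: a gpu.func with two function
// arguments, one workgroup attribution and one private attribution needs an
// entry block with at least three arguments (2 function arguments + 1
// workgroup attribution), and the leading two block arguments must match the
// function argument types.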
1730 
1731 //===----------------------------------------------------------------------===//
1732 // ReturnOp
1733 //===----------------------------------------------------------------------===//
1734 
1735 LogicalResult gpu::ReturnOp::verify() {
1736  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1737 
1738  FunctionType funType = function.getFunctionType();
1739 
1740  if (funType.getNumResults() != getOperands().size())
1741  return emitOpError()
1742  .append("expected ", funType.getNumResults(), " result operands")
1743  .attachNote(function.getLoc())
1744  .append("return type declared here");
1745 
1746  for (const auto &pair : llvm::enumerate(
1747  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1748  auto [type, operand] = pair.value();
1749  if (type != operand.getType())
1750  return emitOpError() << "unexpected type `" << operand.getType()
1751  << "' for operand #" << pair.index();
1752  }
1753  return success();
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // GPUModuleOp
1758 //===----------------------------------------------------------------------===//
1759 
1760 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1761  StringRef name, ArrayAttr targets,
1762  Attribute offloadingHandler) {
1763  result.addRegion()->emplaceBlock();
1764  Properties &props = result.getOrAddProperties<Properties>();
1765  if (targets)
1766  props.targets = targets;
1767  props.setSymName(builder.getStringAttr(name));
1768  props.offloadingHandler = offloadingHandler;
1769 }
1770 
1771 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1772  StringRef name, ArrayRef<Attribute> targets,
1773  Attribute offloadingHandler) {
1774  build(builder, result, name,
1775  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1776  offloadingHandler);
1777 }
1778 
1779 bool GPUModuleOp::hasTarget(Attribute target) {
1780  if (ArrayAttr targets = getTargetsAttr())
1781  return llvm::count(targets.getValue(), target);
1782  return false;
1783 }
1784 
1785 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1786  ArrayAttr &targetsAttr = getProperties().targets;
1787  SmallVector<Attribute> targetsVector(targets);
1788  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1789 }
1790 
1791 //===----------------------------------------------------------------------===//
1792 // GPUBinaryOp
1793 //===----------------------------------------------------------------------===//
1794 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1795  Attribute offloadingHandler, ArrayAttr objects) {
1796  auto &properties = result.getOrAddProperties<Properties>();
1797  result.attributes.push_back(builder.getNamedAttr(
1798  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1799  properties.objects = objects;
1800  if (offloadingHandler)
1801  properties.offloadingHandler = offloadingHandler;
1802  else
1803  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1804 }
1805 
1806 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1807  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1808  build(builder, result, name, offloadingHandler,
1809  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1810 }
1811 
1812 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1813  Attribute &offloadingHandler) {
1814  if (succeeded(parser.parseOptionalLess())) {
1815  if (parser.parseAttribute(offloadingHandler))
1816  return failure();
1817  if (parser.parseGreater())
1818  return failure();
1819  }
1820  if (!offloadingHandler)
1821  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1822  return success();
1823 }
1824 
1825 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1826  Attribute offloadingHandler) {
1827  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1828  printer << '<' << offloadingHandler << '>';
1829 }
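// The handler is printed in angle brackets after the symbol name and elided
// when it equals the default `SelectObjectAttr` built by the parser above,
// e.g. (illustrative objects):
//
//   gpu.binary @mod  [#gpu.object<#nvvm.target, "...">]
//   gpu.binary @mod2 <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#nvvm.target, "...">]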
1830 
1831 //===----------------------------------------------------------------------===//
1832 // GPUMemcpyOp
1833 //===----------------------------------------------------------------------===//
1834 
1835 LogicalResult MemcpyOp::verify() {
1836  auto srcType = getSrc().getType();
1837  auto dstType = getDst().getType();
1838 
1839  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1840  return emitOpError("arguments have incompatible element type");
1841 
1842  if (failed(verifyCompatibleShape(srcType, dstType)))
1843  return emitOpError("arguments have incompatible shape");
1844 
1845  return success();
1846 }
1847 
1848 namespace {
1849 
1850 /// Erases a common case of copy ops where a destination value is used only by
1851 /// the copy op, alloc and dealloc ops.
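/// An illustrative sketch of the IR handled here (hypothetical shapes):
///   %dst = gpu.alloc() : memref<16xf32>
///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
///   gpu.dealloc %dst : memref<16xf32>
/// %dst is only written by the copy and then freed, so the copy can go.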
1852 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1853  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1854 
1855  LogicalResult matchAndRewrite(MemcpyOp op,
1856  PatternRewriter &rewriter) const override {
1857  Value dest = op.getDst();
1858  Operation *destDefOp = dest.getDefiningOp();
1859  // `dest` must be defined by an op having the Allocate memory effect in
1860  // order to perform the folding.
1861  if (!destDefOp ||
1862  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1863  return failure();
1864  // We can erase `op` iff `dest` has no other use apart from its
1865  // use by `op` and dealloc ops.
1866  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1867  return user != op &&
1868  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1869  }))
1870  return failure();
1871  // We can perform the folding if and only if op has a single async
1872  // dependency and produces an async token as result, or if it does not have
1873  // any async dependency and does not produce any async token result.
1874  if (op.getAsyncDependencies().size() > 1 ||
1875  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1876  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1877  return failure();
1878  rewriter.replaceOp(op, op.getAsyncDependencies());
1879  return success();
1880  }
1881 };
1882 
1883 } // end anonymous namespace
1884 
1885 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1886  MLIRContext *context) {
1887  results.add<EraseTrivialCopyOp>(context);
1888 }
1889 
1890 //===----------------------------------------------------------------------===//
1891 // GPU_SubgroupMmaLoadMatrixOp
1892 //===----------------------------------------------------------------------===//
1893 
1894 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1895  auto srcType = getSrcMemref().getType();
1896  auto resType = getRes().getType();
1897  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1898  auto operand = resMatrixType.getOperand();
1899  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1900 
1901  if (!isLastMemrefDimUnitStride(srcMemrefType))
1902  return emitError(
1903  "expected source memref most minor dim must have unit stride");
1904 
1905  if (operand != "AOp" && operand != "BOp" && operand != "COp")
1906  return emitError("only AOp, BOp and COp can be loaded");
1907 
1908  return success();
1909 }
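// An illustrative well-formed load, assuming the op's documented assembly
// (the leadDimension value is arbitrary):
//   %0 = gpu.subgroup_mma_load_matrix %src[%c0, %c0]
//        {leadDimension = 32 : index}
//        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">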
1910 
1911 //===----------------------------------------------------------------------===//
1912 // GPU_SubgroupMmaStoreMatrixOp
1913 //===----------------------------------------------------------------------===//
1914 
1915 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1916  auto srcType = getSrc().getType();
1917  auto dstType = getDstMemref().getType();
1918  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1919  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1920 
1921  if (!isLastMemrefDimUnitStride(dstMemrefType))
1922  return emitError(
1923  "expected destination memref most minor dim must have unit stride");
1924 
1925  if (srcMatrixType.getOperand() != "COp")
1926  return emitError(
1927  "expected the operand matrix being stored to have 'COp' operand type");
1928 
1929  return success();
1930 }
1931 
1932 //===----------------------------------------------------------------------===//
1933 // GPU_SubgroupMmaComputeOp
1934 //===----------------------------------------------------------------------===//
1935 
1936 LogicalResult SubgroupMmaComputeOp::verify() {
1937  enum OperandMap { A, B, C };
1938  SmallVector<MMAMatrixType, 3> opTypes;
1939  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1940  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1941  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1942 
1943  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1944  opTypes[C].getOperand() != "COp")
1945  return emitError("operands must be in the order AOp, BOp, COp");
1946 
1947  ArrayRef<int64_t> aShape, bShape, cShape;
1948  aShape = opTypes[A].getShape();
1949  bShape = opTypes[B].getShape();
1950  cShape = opTypes[C].getShape();
1951 
1952  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1953  bShape[1] != cShape[1])
1954  return emitError("operand shapes do not satisfy matmul constraints");
1955 
1956  return success();
1957 }
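// Worked example of the constraint above: with A = 16x32 ("AOp"),
// B = 32x16 ("BOp") and C = 16x16 ("COp"), aShape[1] == bShape[0] == 32,
// aShape[0] == cShape[0] == 16 and bShape[1] == cShape[1] == 16, i.e. a
// valid 16x32 * 32x16 -> 16x16 multiply-accumulate.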
1958 
1959 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1960  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1961  return memref::foldMemRefCast(*this);
1962 }
1963 
1964 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1965  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1966  return memref::foldMemRefCast(*this);
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // GPU_WaitOp
1971 //===----------------------------------------------------------------------===//
1972 
1973 namespace {
1974 
1975 /// Remove gpu.wait dependencies on gpu.wait ops that themselves have no
1976 /// %t = gpu.wait async [] // No async dependencies.
1977 /// ... gpu.wait ... [%t, ...] // %t can be removed.
1978 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1979 public:
1980  using OpRewritePattern::OpRewritePattern;
1981 
1982  LogicalResult matchAndRewrite(WaitOp op,
1983  PatternRewriter &rewriter) const final {
1984  auto predicate = [](Value value) {
1985  auto waitOp = value.getDefiningOp<WaitOp>();
1986  return waitOp && waitOp->getNumOperands() == 0;
1987  };
1988  if (llvm::none_of(op.getAsyncDependencies(), predicate))
1989  return failure();
1990  SmallVector<Value> validOperands;
1991  for (Value operand : op->getOperands()) {
1992  if (predicate(operand))
1993  continue;
1994  validOperands.push_back(operand);
1995  }
1996  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
1997  return success();
1998  }
1999 };
2000 
2001 /// Simplify trivial gpu.wait ops for the following patterns.
2002 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2003 /// dependencies).
2004 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2005 /// %t0.
2006 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2007 /// dependencies nor return any token.
2008 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2009 public:
2010  using OpRewritePattern::OpRewritePattern;
2011 
2012  LogicalResult matchAndRewrite(WaitOp op,
2013  PatternRewriter &rewriter) const final {
2014  // Erase gpu.wait ops that neither have any async dependencies nor return
2015  // any async token.
2016  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2017  rewriter.eraseOp(op);
2018  return success();
2019  }
2020  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2021  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2022  op.getAsyncToken()) {
2023  rewriter.replaceOp(op, op.getAsyncDependencies());
2024  return success();
2025  }
2026  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2027  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2028  rewriter.eraseOp(op);
2029  return success();
2030  }
2031  return failure();
2032  }
2033 };
2034 
2035 } // end anonymous namespace
2036 
2037 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2038  MLIRContext *context) {
2039  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2040 }
2041 
2042 //===----------------------------------------------------------------------===//
2043 // GPU_AllocOp
2044 //===----------------------------------------------------------------------===//
2045 
2046 LogicalResult AllocOp::verify() {
2047  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2048 
2049  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2050  return emitOpError("dimension operand count does not equal memref "
2051  "dynamic dimension count");
2052 
2053  unsigned numSymbols = 0;
2054  if (!memRefType.getLayout().isIdentity())
2055  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2056  if (getSymbolOperands().size() != numSymbols) {
2057  return emitOpError(
2058  "symbol operand count does not equal memref symbol count");
2059  }
2060 
2061  return success();
2062 }
2063 
2064 namespace {
2065 
2066 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2067 /// `memref::AllocOp`.
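/// Illustrative sketch (hypothetical values):
///   %m = gpu.alloc(%sz) : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// folds every use of %d into %sz.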
2068 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2069  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2070 
2071  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2072  PatternRewriter &rewriter) const override {
2073  std::optional<int64_t> index = dimOp.getConstantIndex();
2074  if (!index)
2075  return failure();
2076 
2077  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2078  if (!memrefType || !memrefType.isDynamicDim(index.value()))
2079  return failure();
2080 
2081  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2082  if (!alloc)
2083  return failure();
2084 
2085  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2086  memrefType.getDynamicDimIndex(index.value()));
2087  rewriter.replaceOp(dimOp, substituteOp);
2088  return success();
2089  }
2090 };
2091 
2092 } // namespace
2093 
2094 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095  MLIRContext *context) {
2096  results.add<SimplifyDimOfAllocOp>(context);
2097 }
2098 
2099 //===----------------------------------------------------------------------===//
2100 // GPU object attribute
2101 //===----------------------------------------------------------------------===//
2102 
2103 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2104  Attribute target, CompilationTarget format,
2105  StringAttr object, DictionaryAttr properties,
2106  KernelTableAttr kernels) {
2107  if (!target)
2108  return emitError() << "the target attribute cannot be null";
2109  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2110  return success();
2111  return emitError() << "the target attribute must implement or promise the "
2112  "`gpu::TargetAttrInterface`";
2113 }
2114 
2115 namespace {
2116 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2117  StringAttr &object) {
2118  std::optional<CompilationTarget> formatResult;
2119  StringRef enumKeyword;
2120  auto loc = odsParser.getCurrentLocation();
2121  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2122  formatResult = CompilationTarget::Fatbin;
2123  if (!formatResult &&
2124  (formatResult =
2125  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2126  odsParser.parseEqual())
2127  return odsParser.emitError(loc, "expected an equal sign");
2128  if (!formatResult)
2129  return odsParser.emitError(loc, "expected keyword for GPU object format");
2130  FailureOr<StringAttr> objectResult =
2131  FieldParser<StringAttr>::parse(odsParser);
2132  if (failed(objectResult))
2133  return odsParser.emitError(odsParser.getCurrentLocation(),
2134  "failed to parse GPU_ObjectAttr parameter "
2135  "'object' which is to be a `StringAttr`");
2136  format = *formatResult;
2137  object = *objectResult;
2138  return success();
2139 }
2140 
2141 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2142  StringAttr object) {
2143  if (format != CompilationTarget::Fatbin)
2144  odsParser << stringifyEnum(format) << " = ";
2145  odsParser << object;
2146 }
2147 } // namespace
2148 
2149 //===----------------------------------------------------------------------===//
2150 // GPU select object attribute
2151 //===----------------------------------------------------------------------===//
2152 
2153 LogicalResult
2154 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2155  Attribute target) {
2156  // Check `target`: it can be null, an integer attr, or a GPU target attribute.
2157  if (target) {
2158  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2159  if (intAttr.getInt() < 0) {
2160  return emitError() << "the object index must be non-negative";
2161  }
2162  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2163  return emitError()
2164  << "the target attribute must be a GPU Target attribute";
2165  }
2166  }
2167  return success();
2168 }
2169 
2170 //===----------------------------------------------------------------------===//
2171 // DynamicSharedMemoryOp
2172 //===----------------------------------------------------------------------===//
2173 
2174 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2175  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2176  return emitOpError() << "must be inside an op with symbol table";
2177 
2178  MemRefType memrefType = getResultMemref().getType();
2179  // Check address space
2180  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2181  return emitOpError() << "address space must be "
2182  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2183  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2184  }
2185  if (memrefType.hasStaticShape()) {
2186  return emitOpError() << "result memref type must be memref<?xi8, "
2187  "#gpu.address_space<workgroup>>";
2188  }
2189  return success();
2190 }
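// Per the checks above, the only well-formed result is a dynamically shaped
// i8 memref in workgroup memory, e.g.:
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>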
2191 
2192 //===----------------------------------------------------------------------===//
2193 // GPU WarpExecuteOnLane0Op
2194 //===----------------------------------------------------------------------===//
2195 
2196 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2197  p << "(" << getLaneid() << ")";
2198 
2199  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2200  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2201  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2202 
2203  if (!getArgs().empty())
2204  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2205  if (!getResults().empty())
2206  p << " -> (" << getResults().getTypes() << ')';
2207  p << " ";
2208  p.printRegion(getRegion(),
2209  /*printEntryBlockArgs=*/true,
2210  /*printBlockTerminators=*/!getResults().empty());
2211  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2212 }
2213 
2214 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2215  OperationState &result) {
2216  // Create the region.
2217  result.regions.reserve(1);
2218  Region *warpRegion = result.addRegion();
2219 
2220  auto &builder = parser.getBuilder();
2221  OpAsmParser::UnresolvedOperand laneId;
2222 
2223  // Parse predicate operand.
2224  if (parser.parseLParen() ||
2225  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2226  parser.parseRParen())
2227  return failure();
2228 
2229  int64_t warpSize;
2230  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2231  parser.parseRSquare())
2232  return failure();
2233  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2234  builder.getContext())),
2235  builder.getI64IntegerAttr(warpSize));
2236 
2237  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2238  return failure();
2239 
2240  llvm::SMLoc inputsOperandsLoc;
2241  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2242  SmallVector<Type> inputTypes;
2243  if (succeeded(parser.parseOptionalKeyword("args"))) {
2244  if (parser.parseLParen())
2245  return failure();
2246 
2247  inputsOperandsLoc = parser.getCurrentLocation();
2248  if (parser.parseOperandList(inputsOperands) ||
2249  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2250  return failure();
2251  }
2252  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2253  result.operands))
2254  return failure();
2255 
2256  // Parse optional results type list.
2257  if (parser.parseOptionalArrowTypeList(result.types))
2258  return failure();
2259  // Parse the region.
2260  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2261  /*argTypes=*/{}))
2262  return failure();
2263  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2264 
2265  // Parse the optional attribute list.
2266  if (parser.parseOptionalAttrDict(result.attributes))
2267  return failure();
2268  return success();
2269 }
2270 
2271 void WarpExecuteOnLane0Op::getSuccessorRegions(
2272  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2273  if (!point.isParent()) {
2274  regions.push_back(RegionSuccessor(getResults()));
2275  return;
2276  }
2277 
2278  // The warp region is always executed
2279  regions.push_back(RegionSuccessor(&getWarpRegion()));
2280 }
2281 
2282 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2283  TypeRange resultTypes, Value laneId,
2284  int64_t warpSize) {
2285  build(builder, result, resultTypes, laneId, warpSize,
2286  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2287 }
2288 
2289 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2290  TypeRange resultTypes, Value laneId,
2291  int64_t warpSize, ValueRange args,
2292  TypeRange blockArgTypes) {
2293  result.addOperands(laneId);
2294  result.addAttribute(getAttributeNames()[0],
2295  builder.getI64IntegerAttr(warpSize));
2296  result.addTypes(resultTypes);
2297  result.addOperands(args);
2298  assert(args.size() == blockArgTypes.size());
2299  OpBuilder::InsertionGuard guard(builder);
2300  Region *warpRegion = result.addRegion();
2301  Block *block = builder.createBlock(warpRegion);
2302  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2303  block->addArgument(type, arg.getLoc());
2304 }
2305 
2306 /// Helper to check that the distributed vector type is consistent with the
2307 /// expanded type and the warp size.
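/// For example, with a warp size of 32, an expanded vector<32x4xf32> may be
/// distributed as vector<1x4xf32> (scales 32 x 1 = 32); a vector<2x4xf32>
/// candidate is rejected because its scales give 16 x 1 != 32.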
2308 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2309  int64_t warpSize, Operation *op) {
2310  // If the types match, there is no distribution.
2311  if (expanded == distributed)
2312  return success();
2313  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2314  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2315  if (!expandedVecType || !distributedVecType)
2316  return op->emitOpError("expected vector type for distributed operands.");
2317  if (expandedVecType.getRank() != distributedVecType.getRank() ||
2318  expandedVecType.getElementType() != distributedVecType.getElementType())
2319  return op->emitOpError(
2320  "expected distributed vectors to have same rank and element type.");
2321 
2322  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2323  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2324  int64_t eDim = expandedVecType.getDimSize(i);
2325  int64_t dDim = distributedVecType.getDimSize(i);
2326  if (eDim == dDim)
2327  continue;
2328  if (eDim % dDim != 0)
2329  return op->emitOpError()
2330  << "expected expanded vector dimension #" << i << " (" << eDim
2331  << ") to be a multipler of the distributed vector dimension ("
2332  << dDim << ")";
2333  scales[i] = eDim / dDim;
2334  }
2335  if (std::accumulate(scales.begin(), scales.end(), 1,
2336  std::multiplies<int64_t>()) != warpSize)
2337  return op->emitOpError()
2338  << "incompatible distribution dimensions from " << expandedVecType
2339  << " to " << distributedVecType << " with warp size = " << warpSize;
2340 
2341  return success();
2342 }
2343 
2344 LogicalResult WarpExecuteOnLane0Op::verify() {
2345  if (getArgs().size() != getWarpRegion().getNumArguments())
2346  return emitOpError(
2347  "expected same number op arguments and block arguments.");
2348  auto yield =
2349  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2350  if (yield.getNumOperands() != getNumResults())
2351  return emitOpError(
2352  "expected same number of yield operands and return values.");
2353  int64_t warpSize = getWarpSize();
2354  for (auto [regionArg, arg] :
2355  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2356  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2357  warpSize, getOperation())))
2358  return failure();
2359  }
2360  for (auto [yieldOperand, result] :
2361  llvm::zip_equal(yield.getOperands(), getResults())) {
2362  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2363  warpSize, getOperation())))
2364  return failure();
2365  }
2366  return success();
2367 }
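// Putting the checks together, a shape-consistent instance at warp size 32
// looks roughly like this (illustrative IR):
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg : vector<128xf32>):
//     ...
//     gpu.yield %val : vector<32xf32>
//   }
// The vector<4xf32> operand expands to a vector<128xf32> block argument, and
// the yielded vector<32xf32> distributes to the vector<1xf32> result.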
2368 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2369  return succeeded(
2370  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2371 }
2372 
2373 //===----------------------------------------------------------------------===//
2374 // GPU KernelMetadataAttr
2375 //===----------------------------------------------------------------------===//
2376 
2377 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2378  DictionaryAttr metadata) {
2379  assert(kernel && "invalid kernel");
2380  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2381  kernel.getAllArgAttrs(), metadata);
2382 }
2383 
2384 KernelMetadataAttr
2385 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2386  FunctionOpInterface kernel,
2387  DictionaryAttr metadata) {
2388  assert(kernel && "invalid kernel");
2389  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2390  kernel.getAllArgAttrs(), metadata);
2391 }
2392 
2393 KernelMetadataAttr
2394 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2395  if (attrs.empty())
2396  return *this;
2397  NamedAttrList attrList;
2398  if (DictionaryAttr dict = getMetadata())
2399  attrList.append(dict);
2400  attrList.append(attrs);
2401  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2402  attrList.getDictionary(getContext()));
2403 }
2404 
2405 LogicalResult
2406 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2407  StringAttr name, Type functionType,
2408  ArrayAttr argAttrs, DictionaryAttr metadata) {
2409  if (name.empty())
2410  return emitError() << "the kernel name can't be empty";
2411  if (argAttrs) {
2412  if (llvm::any_of(argAttrs, [](Attribute attr) {
2413  return !llvm::isa<DictionaryAttr>(attr);
2414  }))
2415  return emitError()
2416  << "all attributes in the array must be a dictionary attribute";
2417  }
2418  return success();
2419 }
2420 
2421 //===----------------------------------------------------------------------===//
2422 // GPU KernelTableAttr
2423 //===----------------------------------------------------------------------===//
2424 
2425 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2426  ArrayRef<KernelMetadataAttr> kernels,
2427  bool isSorted) {
2428  // Note that `llvm::is_sorted` is invoked at most once, even with assertions enabled.
2429  assert((!isSorted || llvm::is_sorted(kernels)) &&
2430  "expected a sorted kernel array");
2431  // Immediately return the attribute if the array is sorted.
2432  if (isSorted || llvm::is_sorted(kernels))
2433  return Base::get(context, kernels);
2434  // Sort the array.
2435  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2436  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2437  return Base::get(context, kernelsTmp);
2438 }
2439 
2440 KernelTableAttr KernelTableAttr::getChecked(
2441  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2442  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2443  // Note that `llvm::is_sorted` is invoked at most once, even with assertions enabled.
2444  assert((!isSorted || llvm::is_sorted(kernels)) &&
2445  "expected a sorted kernel array");
2446  // Immediately return the attribute if the array is sorted.
2447  if (isSorted || llvm::is_sorted(kernels))
2448  return Base::getChecked(emitError, context, kernels);
2449  // Sort the array.
2450  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2451  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2452  return Base::getChecked(emitError, context, kernelsTmp);
2453 }
2454 
2455 LogicalResult
2456 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2457  ArrayRef<KernelMetadataAttr> kernels) {
2458  if (kernels.size() < 2)
2459  return success();
2460  // Check that the kernels are uniquely named.
2461  if (std::adjacent_find(kernels.begin(), kernels.end(),
2462  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2463  return l.getName() == r.getName();
2464  }) != kernels.end()) {
2465  return emitError() << "expected all kernels to be uniquely named";
2466  }
2467  return success();
2468 }
2469 
2470 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2471  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2472  return found ? *iterator : KernelMetadataAttr();
2473 }
2474 
2475 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2476  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2477  return found ? *iterator : KernelMetadataAttr();
2478 }
2479 
2480 //===----------------------------------------------------------------------===//
2481 // GPU target options
2482 //===----------------------------------------------------------------------===//
2483 
2484 TargetOptions::TargetOptions(
2485  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2486  StringRef cmdOptions, CompilationTarget compilationTarget,
2487  function_ref<SymbolTable *()> getSymbolTableCallback,
2488  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2489  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2490  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2491  function_ref<void(StringRef)> isaCallback)
2492  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2493  cmdOptions, compilationTarget, getSymbolTableCallback,
2494  initialLlvmIRCallback, linkedLlvmIRCallback,
2495  optimizedLlvmIRCallback, isaCallback) {}
2496 
2497 TargetOptions::TargetOptions(
2498  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2499  StringRef cmdOptions, CompilationTarget compilationTarget,
2500  function_ref<SymbolTable *()> getSymbolTableCallback,
2501  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2502  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2503  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2504  function_ref<void(StringRef)> isaCallback)
2505  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2506  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2507  getSymbolTableCallback(getSymbolTableCallback),
2508  initialLlvmIRCallback(initialLlvmIRCallback),
2509  linkedLlvmIRCallback(linkedLlvmIRCallback),
2510  optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2511  isaCallback(isaCallback), typeID(typeID) {}
2512 
2513 TypeID TargetOptions::getTypeID() const { return typeID; }
2514 
2515 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2516 
2517 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2518 
2519 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2520 
2521 SymbolTable *TargetOptions::getSymbolTable() const {
2522  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2523 }
2524 
2525 function_ref<void(llvm::Module &)>
2526 TargetOptions::getInitialLlvmIRCallback() const {
2527  return initialLlvmIRCallback;
2528 }
2529 
2530 function_ref<void(llvm::Module &)>
2531 TargetOptions::getLinkedLlvmIRCallback() const {
2532  return linkedLlvmIRCallback;
2533 }
2534 
2535 function_ref<void(llvm::Module &)>
2536 TargetOptions::getOptimizedLlvmIRCallback() const {
2537  return optimizedLlvmIRCallback;
2538 }
2539 
2540 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2541  return isaCallback;
2542 }
2543 
2544 CompilationTarget TargetOptions::getCompilationTarget() const {
2545  return compilationTarget;
2546 }
2547 
2548 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2549  return CompilationTarget::Fatbin;
2550 }
2551 
2552 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2553 TargetOptions::tokenizeCmdOptions() const {
2554  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2555  llvm::StringSaver stringSaver(options.first);
2556  StringRef opts = cmdOptions;
2557  // For a correct tokenization of the command-line options, `opts` must be
2558  // unquoted; otherwise the tokenization function returns a single string (the
2559  // unquoted `cmdOptions`), which is not the desired behavior.
2560  // Remove any quotes if they are at the beginning and end of the string:
2561  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2562  opts.consume_front("\""), opts.consume_back("\"");
2563  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2564  opts.consume_front("'"), opts.consume_back("'");
2565 #ifdef _WIN32
2566  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2567  /*MarkEOLs=*/false);
2568 #else
2569  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2570  /*MarkEOLs=*/false);
2571 #endif // _WIN32
2572  return options;
2573 }
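// A minimal usage sketch (hypothetical options string, relying on the
// declaration's default arguments for the remaining parameters): the
// surrounding quotes are stripped before tokenization, so the call yields
// the tokens {"-O3", "--verbose"}.
//
//   TargetOptions opts(/*toolkitPath=*/"", /*linkFiles=*/{},
//                      /*cmdOptions=*/"\"-O3 --verbose\"");
//   auto [allocator, tokens] = opts.tokenizeCmdOptions();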
2574 
2575 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2576 
2577 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2578 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2579 
2580 #define GET_ATTRDEF_CLASSES
2581 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2582 
2583 #define GET_OP_CLASSES
2584 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2585 
2586 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
Definition: GPUDialect.cpp:422
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Definition: GPUDialect.cpp:586
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
Definition: GPUDialect.cpp:438
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
Definition: GPUDialect.cpp:897
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
Definition: GPUDialect.cpp:472
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
Definition: GPUDialect.cpp:564
static std::string getSparseHandleKeyword(SparseHandleKind kind)
Definition: GPUDialect.cpp:222
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
Definition: GPUDialect.cpp:312
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Definition: GPUDialect.cpp:599
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
Definition: GPUDialect.cpp:461
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
Definition: GPUDialect.cpp:510
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
Definition: GPUDialect.cpp:848
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
Definition: GPUDialect.cpp:484
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MINUI(lhs, rhs)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:263
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:99
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
IndexType getIndexType()
Definition: Builders.cpp:95
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:144
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:134
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:106
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isIndex() const
Definition: Types.cpp:64
bool isF32() const
Definition: Types.cpp:59
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isF16() const
Definition: Types.cpp:57
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk.
Definition: Visitors.h:33
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:137
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:122
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:141
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
Definition: GPUDialect.cpp:145
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition: GPUDialect.cpp:152
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
Definition: GPUDialect.cpp:143
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
Definition: GPUDialect.cpp:128
unsigned getNumDims() const
Get number of dims.
Definition: GPUDialect.cpp:135
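A minimal sketch (assuming an MLIRContext with the GPU dialect loaded) that builds and inspects an MMAMatrixType with the accessors above; getChecked paired with getDefaultDiagnosticEmitFn can be used instead of get to emit a diagnostic rather than assert on invalid input:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

using namespace mlir;

gpu::MMAMatrixType buildAOperand(MLIRContext *ctx) {
  Type f16 = Float16Type::get(ctx);
  // "AOp" selects the A operand of C += A*B (see getOperand above).
  auto mat = gpu::MMAMatrixType::get({16, 16}, f16, "AOp");
  assert(mat.getNumDims() == 2 && mat.getElementType().isF16());
  return mat;
}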
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Definition: GPUDialect.cpp:663
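A short sketch (hypothetical wrapper; assumes addAsyncDependency lives in the mlir::gpu namespace as listed above) of appending one more !gpu.async.token dependency to an async GPU op:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include <cassert>

using namespace mlir;

static void addExtraWait(Operation *asyncOp, Value token) {
  // The token typically comes from a `gpu.wait async` or another async gpu
  // op; addAsyncDependency inserts it into the op's async-dependency operands.
  assert(isa<gpu::AsyncTokenType>(token.getType()) && "expected async token");
  gpu::addAsyncDependency(asyncOp, token);
}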
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:473
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR type in the given MLIR context, returning it if valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:426
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
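A simplified sketch of the idea behind this pattern (not the exact implementation in this file): use matchPattern with m_One to detect a unit block dimension, then rewrite uses of the corresponding thread ID to a zero constant:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult foldUnitBlockDimX(gpu::LaunchOp op,
                                       PatternRewriter &rewriter) {
  // Only fire when blockSizeX is the constant 1 and the x thread ID is used.
  if (!matchPattern(op.getBlockSizeX(), m_One()) ||
      op.getThreadIds().x.use_empty())
    return failure();
  rewriter.setInsertionPointToStart(&op.getBody().front());
  Value zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
  rewriter.replaceAllUsesWith(op.getThreadIds().x, zero);
  return success();
}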
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z similarly to CUDA notation.
Definition: GPUDialect.h:39