MLIR  15.0.0git
NVGPUToNVVM.cpp
Go to the documentation of this file.
1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
16 
17 using namespace mlir;
18 
19 /// Returns the type for the intrinsic given the vectorResultType of the
20 /// `gpu.mma.sync` operation.
21 static Type inferIntrinsicResultType(Type vectorResultType) {
22  MLIRContext *ctx = vectorResultType.getContext();
23  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
24  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
25  auto i32Ty = IntegerType::get(ctx, 32);
26  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
27  Type f64Ty = Float64Type::get(ctx);
28  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
29  Type f32Ty = Float32Type::get(ctx);
30  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
31  if (a.getElementType() == f16x2Ty) {
33  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
34  }
35  if (a.getElementType() == i32x2Ty) {
37  ctx,
38  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
39  }
40  if (a.getElementType() == f64x2Ty) {
41  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
42  }
43  if (a.getElementType() == f32x2Ty) {
45  ctx,
46  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
47  }
48  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
50  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
51  }
52  return vectorResultType;
53 }
54 
55 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
56 /// always an LLVM struct) into a fragment that is compatible with the vector
57 /// type of this operation. This involves extracting elements from the struct
58 /// and inserting them into an LLVM array. These extra data-movement
59 /// operations should be canonicalized away by the LLVM backend.
60 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
61  Type resultType, Value intrinsicResult,
62  RewriterBase &rewriter) {
63  MLIRContext *ctx = rewriter.getContext();
64  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
65  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
66  Type i32Ty = rewriter.getI32Type();
67  Type f32Ty = rewriter.getF32Type();
68  Type f64Ty = rewriter.getF64Type();
69  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
70  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
71  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
72  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
73  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
74 
75  auto makeConst = [&](int32_t index) -> Value {
76  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
77  rewriter.getI32IntegerAttr(index));
78  };
79 
80  if (arrayType) {
81  SmallVector<Value, 4> elements;
82 
83  // The intrinsic returns 32-bit wide elements in a form which can be
84  // directly bitcasted and inserted into the result vector.
85  if (arrayType.getElementType() == f16x2Ty ||
86  arrayType.getElementType() == f32x1Ty) {
87  for (unsigned i = 0; i < structType.getBody().size(); i++) {
88  Value el = rewriter.create<LLVM::ExtractValueOp>(
89  loc, structType.getBody()[i], intrinsicResult,
90  rewriter.getI64ArrayAttr(i));
91  el = rewriter.createOrFold<LLVM::BitcastOp>(
92  loc, arrayType.getElementType(), el);
93  elements.push_back(el);
94  }
95  }
96 
97  // The intrinsic returns i32, f64, and f32 values as individual scalars,
98  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
99  // need to extract them from the struct and pack them into the 64-bit wide
100  // rows of the vector result.
101  if (arrayType.getElementType() == i32x2Ty ||
102  arrayType.getElementType() == f64x2Ty ||
103  arrayType.getElementType() == f32x2Ty) {
104 
105  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
106  Value vec =
107  rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
108  Value x1 = rewriter.create<LLVM::ExtractValueOp>(
109  loc, structType.getBody()[i * 2], intrinsicResult,
110  rewriter.getI64ArrayAttr(i * 2));
111  Value x2 = rewriter.create<LLVM::ExtractValueOp>(
112  loc, structType.getBody()[i * 2 + 1], intrinsicResult,
113  rewriter.getI64ArrayAttr(i * 2 + 1));
114  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
115  x1, makeConst(0));
116  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
117  x2, makeConst(1));
118  elements.push_back(vec);
119  }
120  }
121 
122  // Create the final vectorized result.
123  Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
124  for (const auto &el : llvm::enumerate(elements)) {
125  result = rewriter.create<LLVM::InsertValueOp>(
126  loc, arrayType, result, el.value(),
127  rewriter.getI64ArrayAttr(el.index()));
128  }
129  return result;
130  }
131 
132  return intrinsicResult;
133 }
134 
135 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
136 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
137 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
138 /// scalars of certain types. This function helps unpack the `vector` arguments
139 /// and cast them to the types expected by `nvvm.mma.sync`.
141  Location loc, Value operand,
142  NVVM::MMATypes operandPtxType) {
143  SmallVector<Value> result;
144  Type i32Ty = rewriter.getI32Type();
145  Type f64Ty = rewriter.getF64Type();
146  Type f32Ty = rewriter.getF32Type();
147  Type i8Ty = rewriter.getI8Type();
148  Type i4Ty = rewriter.getIntegerType(4);
149  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
150  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
151  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
152  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
153 
154  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
155  Value toUse = rewriter.create<LLVM::ExtractValueOp>(
156  loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
157 
158  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
159  // scalar types.
160  if (arrayTy.getElementType() == i8x4Ty ||
161  arrayTy.getElementType() == i4x8Ty ||
162  (arrayTy.getElementType() == f32x1Ty &&
163  operandPtxType == NVVM::MMATypes::tf32)) {
164  result.push_back(
165  rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
166  continue;
167  }
168 
169  // For some element types (i32, f32, f64), we need to unpack the inner
170  // vector/array type as well because the intrinsic expects individual
171  // scalars to be provided.
172  VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
173  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
174  innerArrayTy.getElementType() == f64Ty ||
175  innerArrayTy.getElementType() == f32Ty)) {
176  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
177  idx < innerSize; idx++) {
178  result.push_back(rewriter.create<LLVM::ExtractElementOp>(
179  loc, toUse,
180  rewriter.create<LLVM::ConstantOp>(
181  loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
182  }
183  continue;
184  }
185  result.push_back(toUse);
186  }
187  return result;
188 }
189 
190 namespace {
191 
192 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
194 
196  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198  MLIRContext *ctx = getContext();
199  Location loc = op->getLoc();
200 
201  // The result type of ldmatrix will always be a struct of 32bit integer
202  // registers if more than one 32bit value is returned. Otherwise, the result
203  // is a single i32. The result type of the GPU operation is always a vector
204  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
205  // vector type of the result and always 32 bits long. We bitcast the result
206  // of the NVVM::LdMatrix to this vector type.
207  auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
208  if (!vectorResultType) {
209  return failure();
210  }
211  Type innerVectorType = LLVM::getFixedVectorType(
212  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
213 
214  int64_t num32BitRegs = vectorResultType.getDimSize(0);
215 
216  Type ldMatrixResultType;
217  if (num32BitRegs > 1) {
218  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
219  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
220  } else {
221  ldMatrixResultType = rewriter.getI32Type();
222  }
223 
224  auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
225  Value srcPtr =
226  getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
227  adaptor.getIndices(), rewriter);
228  Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
229  loc, ldMatrixResultType, srcPtr,
230  /*num=*/op.getNumTiles(),
231  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
232  : NVVM::MMALayout::row);
233 
234  // The ldmatrix operation returns either a single i32 value or a struct of
235  // i32 values. Here we unpack those values and cast them back to their
236  // actual vector type (still of width 32b) and repack them into a result
237  // struct.
238  Type finalResultType = typeConverter->convertType(vectorResultType);
239  Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
240  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
241  Value i32Register = num32BitRegs > 1
242  ? rewriter.create<LLVM::ExtractValueOp>(
243  loc, rewriter.getI32Type(), ldMatrixResult,
244  rewriter.getI64ArrayAttr(i))
245  : ldMatrixResult;
246  Value casted =
247  rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
248  result = rewriter.create<LLVM::InsertValueOp>(
249  loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
250  }
251 
252  rewriter.replaceOp(op, result);
253  return success();
254  }
255 };
256 
257 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
259 
261  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
262  ConversionPatternRewriter &rewriter) const override {
263  Location loc = op->getLoc();
264  // Get the shapes of the MMAMatrix type being used. The shapes will
265  // choose which intrinsic this op will be lowered to.
266  auto aType = op.getMatrixA().getType().cast<VectorType>();
267  auto cType = op.getMatrixC().getType().cast<VectorType>();
268 
269  int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
270  int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
271  int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
272  std::array<int64_t, 3> gemmShape{m, n, k};
273 
274  NVVM::MMATypes ptxTypeA;
275  NVVM::MMATypes ptxTypeB;
276  Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
277  cType.getElementType(), /*isAccumulator=*/true);
278  if (!ptxTypeC) {
279  return op->emitError(
280  "could not infer the PTX type for the accumulator/result");
281  }
282 
283  Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
284  if (aType.getElementType().isInteger(8)) {
285  ptxTypeA = NVVM::MMATypes::s8;
286  ptxTypeB = NVVM::MMATypes::s8;
287  overflow = NVVM::MMAIntOverflow::satfinite;
288  } else if (aType.getElementType().isInteger(4)) {
289  ptxTypeA = NVVM::MMATypes::s4;
290  ptxTypeB = NVVM::MMATypes::s4;
291  overflow = NVVM::MMAIntOverflow::satfinite;
292  } else if (aType.getElementType().isF16()) {
293  ptxTypeA = NVVM::MMATypes::f16;
294  ptxTypeB = NVVM::MMATypes::f16;
295  } else if (aType.getElementType().isF64()) {
296  ptxTypeA = NVVM::MMATypes::f64;
297  ptxTypeB = NVVM::MMATypes::f64;
298  } else if (aType.getElementType().isF32()) {
299  ptxTypeA = NVVM::MMATypes::tf32;
300  ptxTypeB = NVVM::MMATypes::tf32;
301  } else {
302  return op->emitError("could not deduce operand PTX types");
303  }
304 
305  SmallVector<Value> matA =
306  unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
307  SmallVector<Value> matB =
308  unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
309  SmallVector<Value> matC =
310  unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
311 
312  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
313  Type intrinsicResTy = inferIntrinsicResultType(
314  typeConverter->convertType(op->getResultTypes()[0]));
315  Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
316  op.getLoc(), intrinsicResTy, matA, matB, matC,
317  /*shape=*/gemmShape,
318  /*b1Op=*/llvm::None,
319  /*intOverflow=*/overflow,
320  /*multiplicandPtxTypes=*/
321  std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
322  /*multiplicandLayouts=*/
323  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
324  NVVM::MMALayout::col});
325  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
326  desiredRetTy, intrinsicResult,
327  rewriter));
328  return success();
329  }
330 };
331 
332 struct ConvertNVGPUToNVVMPass
333  : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
334  ConvertNVGPUToNVVMPass() = default;
335 
336  void runOnOperation() override {
337  RewritePatternSet patterns(&getContext());
338  LLVMTypeConverter converter(&getContext());
339  /// device-side async tokens cannot be materialized in nvvm. We just convert
340  /// them to a dummy i32 type in order to easily drop them during conversion.
341  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
342  return converter.convertType(IntegerType::get(type.getContext(), 32));
343  });
344  populateNVGPUToNVVMConversionPatterns(converter, patterns);
345  LLVMConversionTarget target(getContext());
346  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
347  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
348  if (failed(applyPartialConversion(getOperation(), target,
349  std::move(patterns))))
350  signalPassFailure();
351  }
352 };
353 
354 struct NVGPUAsyncCopyLowering
355  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
357  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
358 
360  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
361  ConversionPatternRewriter &rewriter) const override {
362  Location loc = op->getLoc();
363  auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
364  Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
365  adaptor.getDstIndices(), rewriter);
366  auto i8Ty = IntegerType::get(op.getContext(), 8);
367  auto dstPointerType =
368  LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
369  dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
370 
371  auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
372 
373  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
374  adaptor.getSrcIndices(), rewriter);
375  auto srcPointerType =
376  LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
377  scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
378  // Intrinsics takes a global pointer so we need an address space cast.
379  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
381  scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
382  scrPtr);
383  int64_t numElements = adaptor.getNumElements().getZExtValue();
384  int64_t sizeInBytes =
385  (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
386  // bypass L1 is only supported for byte sizes of 16, we drop the hint
387  // otherwise.
388  UnitAttr bypassL1 =
389  sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
390  rewriter.create<NVVM::CpAsyncOp>(
391  loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
392 
393  // Drop the result token.
394  Value zero = rewriter.create<LLVM::ConstantOp>(
395  op->getLoc(), IntegerType::get(op.getContext(), 32),
396  rewriter.getI32IntegerAttr(0));
397  rewriter.replaceOp(op, zero);
398  return success();
399  }
400 };
401 
402 struct NVGPUAsyncCreateGroupLowering
403  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
405  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
406 
408  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const override {
410  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
411  // Drop the result token.
412  Value zero = rewriter.create<LLVM::ConstantOp>(
413  op->getLoc(), IntegerType::get(op.getContext(), 32),
414  rewriter.getI32IntegerAttr(0));
415  rewriter.replaceOp(op, zero);
416  return success();
417  }
418 };
419 
420 struct NVGPUAsyncWaitLowering
421  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
423  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
424 
426  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
427  ConversionPatternRewriter &rewriter) const override {
428  // If numGroup is not present pick 0 as a conservative correct value.
429  int32_t numGroups = adaptor.getNumGroups() ? *adaptor.getNumGroups() : 0;
430  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
431  rewriter.eraseOp(op);
432  return success();
433  }
434 };
435 
436 } // namespace
437 
439  RewritePatternSet &patterns) {
440  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
441  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
442  converter);
443 }
444 
/// Factory for the NVGPU-to-NVVM conversion pass registered by this file.
std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}
static Type inferIntrinsicResultType(Type vectorResultType)
Returns the type for the intrinsic given the vectorResultType of the gpu.mma.sync operation...
Definition: NVGPUToNVVM.cpp:21
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
MLIRContext * getContext() const
Definition: Builders.h:54
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:458
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:908
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:414
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
std::unique_ptr< Pass > createConvertNVGPUToNVVMPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
FloatType getF16Type()
Definition: Builders.cpp:38
Derived class that automatically populates legalization information for different LLVM ops...
FloatType getF32Type()
Definition: Builders.cpp:40
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type resultType, Value intrinsicResult, RewriterBase &rewriter)
Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is always an LLVM struct) into a fr...
Definition: NVGPUToNVVM.cpp:60
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
U dyn_cast() const
Definition: Types.h:256
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
IntegerType getI8Type()
Definition: Builders.cpp:52
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Device-side token storage type. There is only one type of device-side token.
Definition: NVGPUDialect.h:25
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LLVM dialect array type.
Definition: LLVMTypes.h:75
IntegerType getI64Type()
Definition: Builders.cpp:56
void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:66
void addConversion(FnT &&callback)
Register a conversion function.
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:277
U dyn_cast() const
Definition: Attributes.h:124
FloatType getF64Type()
Definition: Builders.cpp:42
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Global memory space identifier.
Definition: GPUDialect.cpp:94
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
IntegerType getI32Type()
Definition: Builders.cpp:54
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
static SmallVector< Value > unpackOperandVector(RewriterBase &rewriter, Location loc, Value operand, NVVM::MMATypes operandPtxType)
The gpu.mma.sync converter below expects matrix fragment operands to be given as 2D vectors where the...
U cast() const
Definition: Types.h:262