MLIR  17.0.0git
GPUOpsLowering.cpp
Go to the documentation of this file.
1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/FormatVariadic.h"
15 
16 using namespace mlir;
17 
19 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
20  ConversionPatternRewriter &rewriter) const {
21  Location loc = gpuFuncOp.getLoc();
22 
23  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
24  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
25  for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
26  Value attribution = en.value();
27 
28  auto type = attribution.getType().dyn_cast<MemRefType>();
29  assert(type && type.hasStaticShape() && "unexpected type in attribution");
30 
31  uint64_t numElements = type.getNumElements();
32 
33  auto elementType =
34  typeConverter->convertType(type.getElementType()).template cast<Type>();
35  auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
36  std::string name = std::string(
37  llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
38  auto globalOp = rewriter.create<LLVM::GlobalOp>(
39  gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
40  LLVM::Linkage::Internal, name, /*value=*/Attribute(),
41  /*alignment=*/0, workgroupAddrSpace);
42  workgroupBuffers.push_back(globalOp);
43  }
44 
45  // Rewrite the original GPU function to an LLVM function.
46  auto convertedType = typeConverter->convertType(gpuFuncOp.getFunctionType());
47  if (!convertedType)
48  return failure();
49  auto funcType =
50  convertedType.template cast<LLVM::LLVMPointerType>().getElementType();
51 
52  // Remap proper input types.
53  TypeConverter::SignatureConversion signatureConversion(
54  gpuFuncOp.front().getNumArguments());
56  gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
57 
58  // Create the new function operation. Only copy those attributes that are
59  // not specific to function modeling.
61  for (const auto &attr : gpuFuncOp->getAttrs()) {
62  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
63  attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
64  attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
65  continue;
66  attributes.push_back(attr);
67  }
68  // Add a dialect specific kernel attribute in addition to GPU kernel
69  // attribute. The former is necessary for further translation while the
70  // latter is expected by gpu.launch_func.
71  if (gpuFuncOp.isKernel())
72  attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
73  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
74  gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
75  LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
76  attributes);
77 
78  {
79  // Insert operations that correspond to converted workgroup and private
80  // memory attributions to the body of the function. This must operate on
81  // the original function, before the body region is inlined in the new
82  // function to maintain the relation between block arguments and the
83  // parent operation that assigns their semantics.
84  OpBuilder::InsertionGuard guard(rewriter);
85 
86  // Rewrite workgroup memory attributions to addresses of global buffers.
87  rewriter.setInsertionPointToStart(&gpuFuncOp.front());
88  unsigned numProperArguments = gpuFuncOp.getNumArguments();
89 
90  for (const auto &en : llvm::enumerate(workgroupBuffers)) {
91  LLVM::GlobalOp global = en.value();
92  Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
93  auto elementType =
94  global.getType().cast<LLVM::LLVMArrayType>().getElementType();
95  Value memory = rewriter.create<LLVM::GEPOp>(
96  loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
97  address, ArrayRef<LLVM::GEPArg>{0, 0});
98 
99  // Build a memref descriptor pointing to the buffer to plug with the
100  // existing memref infrastructure. This may use more registers than
101  // otherwise necessary given that memref sizes are fixed, but we can try
102  // and canonicalize that away later.
103  Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
104  auto type = attribution.getType().cast<MemRefType>();
106  rewriter, loc, *getTypeConverter(), type, memory);
107  signatureConversion.remapInput(numProperArguments + en.index(), descr);
108  }
109 
110  // Rewrite private memory attributions to alloca'ed buffers.
111  unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
112  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
113  for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
114  Value attribution = en.value();
115  auto type = attribution.getType().cast<MemRefType>();
116  assert(type && type.hasStaticShape() && "unexpected type in attribution");
117 
118  // Explicitly drop memory space when lowering private memory
119  // attributions since NVVM models it as `alloca`s in the default
120  // memory space and does not support `alloca`s with addrspace(5).
121  auto ptrType = LLVM::LLVMPointerType::get(
122  typeConverter->convertType(type.getElementType())
123  .template cast<Type>(),
124  allocaAddrSpace);
125  Value numElements = rewriter.create<LLVM::ConstantOp>(
126  gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
127  Value allocated = rewriter.create<LLVM::AllocaOp>(
128  gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
130  rewriter, loc, *getTypeConverter(), type, allocated);
131  signatureConversion.remapInput(
132  numProperArguments + numWorkgroupAttributions + en.index(), descr);
133  }
134  }
135 
136  // Move the region to the new function, update the entry block signature.
137  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
138  llvmFuncOp.end());
139  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
140  &signatureConversion)))
141  return failure();
142 
143  // If bare memref pointers are being used, remap them back to memref
144  // descriptors. This must be done after signature conversion to get rid of the
145  // unrealized casts.
146  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
147  OpBuilder::InsertionGuard guard(rewriter);
148  rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
149  for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
150  auto memrefTy = en.value().dyn_cast<MemRefType>();
151  if (!memrefTy)
152  continue;
153  assert(memrefTy.hasStaticShape() &&
154  "Bare pointer convertion used with dynamically-shaped memrefs");
155  // Use a placeholder when replacing uses of the memref argument to prevent
156  // circular replacements.
157  auto remapping = signatureConversion.getInputMapping(en.index());
158  assert(remapping && remapping->size == 1 &&
159  "Type converter should produce 1-to-1 mapping for bare memrefs");
160  BlockArgument newArg =
161  llvmFuncOp.getBody().getArgument(remapping->inputNo);
162  auto placeholder = rewriter.create<LLVM::UndefOp>(
163  loc, getTypeConverter()->convertType(memrefTy));
164  rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
166  rewriter, loc, *getTypeConverter(), memrefTy, newArg);
167  rewriter.replaceOp(placeholder, {desc});
168  }
169  }
170 
171  rewriter.eraseOp(gpuFuncOp);
172  return success();
173 }
174 
175 static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
176  const char formatStringPrefix[] = "printfFormat_";
177  // Get a unique global name.
178  unsigned stringNumber = 0;
179  SmallString<16> stringConstName;
180  do {
181  stringConstName.clear();
182  (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
183  } while (moduleOp.lookupSymbol(stringConstName));
184  return stringConstName;
185 }
186 
187 template <typename T>
188 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
189  ConversionPatternRewriter &rewriter,
190  StringRef name,
191  LLVM::LLVMFunctionType type) {
192  LLVM::LLVMFuncOp ret;
193  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
194  ConversionPatternRewriter::InsertionGuard guard(rewriter);
195  rewriter.setInsertionPointToStart(moduleOp.getBody());
196  ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
197  LLVM::Linkage::External);
198  }
199  return ret;
200 }
201 
203  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
204  ConversionPatternRewriter &rewriter) const {
205  Location loc = gpuPrintfOp->getLoc();
206 
207  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
208  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
209  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
210  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
211  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
212  // This ensures that global constants and declarations are placed within
213  // the device code, not the host code.
214  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
215 
216  auto ocklBegin =
217  getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
218  LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
219  LLVM::LLVMFuncOp ocklAppendArgs;
220  if (!adaptor.getArgs().empty()) {
221  ocklAppendArgs = getOrDefineFunction(
222  moduleOp, loc, rewriter, "__ockl_printf_append_args",
223  LLVM::LLVMFunctionType::get(
224  llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
225  llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
226  }
227  auto ocklAppendStringN = getOrDefineFunction(
228  moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
229  LLVM::LLVMFunctionType::get(
230  llvmI64,
231  {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
232 
233  /// Start the printf hostcall
234  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
235  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
236  Value printfDesc = printfBeginCall.getResult();
237 
238  // Get a unique global name for the format.
239  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
240 
241  llvm::SmallString<20> formatString(adaptor.getFormat());
242  formatString.push_back('\0'); // Null terminate for C
243  size_t formatStringSize = formatString.size_in_bytes();
244 
245  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
246  LLVM::GlobalOp global;
247  {
249  rewriter.setInsertionPointToStart(moduleOp.getBody());
250  global = rewriter.create<LLVM::GlobalOp>(
251  loc, globalType,
252  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
253  rewriter.getStringAttr(formatString));
254  }
255 
256  // Get a pointer to the format string's first element and pass it to printf()
257  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
258  Value stringStart = rewriter.create<LLVM::GEPOp>(
259  loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
260  Value stringLen =
261  rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
262 
263  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
264  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
265 
266  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
267  loc, ocklAppendStringN,
268  ValueRange{printfDesc, stringStart, stringLen,
269  adaptor.getArgs().empty() ? oneI32 : zeroI32});
270  printfDesc = appendFormatCall.getResult();
271 
272  // __ockl_printf_append_args takes 7 values per append call
273  constexpr size_t argsPerAppend = 7;
274  size_t nArgs = adaptor.getArgs().size();
275  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
276  size_t bound = std::min(group + argsPerAppend, nArgs);
277  size_t numArgsThisCall = bound - group;
278 
280  arguments.push_back(printfDesc);
281  arguments.push_back(
282  rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
283  for (size_t i = group; i < bound; ++i) {
284  Value arg = adaptor.getArgs()[i];
285  if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
286  if (!floatType.isF64())
287  arg = rewriter.create<LLVM::FPExtOp>(
288  loc, typeConverter->convertType(rewriter.getF64Type()), arg);
289  arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
290  }
291  if (arg.getType().getIntOrFloatBitWidth() != 64)
292  arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
293 
294  arguments.push_back(arg);
295  }
296  // Pad out to 7 arguments since the hostcall always needs 7
297  for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
298  arguments.push_back(zeroI64);
299  }
300 
301  auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
302  arguments.push_back(isLast);
303  auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
304  printfDesc = call.getResult();
305  }
306  rewriter.eraseOp(gpuPrintfOp);
307  return success();
308 }
309 
311  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
312  ConversionPatternRewriter &rewriter) const {
313  Location loc = gpuPrintfOp->getLoc();
314 
315  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
316  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
317 
318  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
319  // This ensures that global constants and declarations are placed within
320  // the device code, not the host code.
321  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
322 
323  auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
324  /*isVarArg=*/true);
325  LLVM::LLVMFuncOp printfDecl =
326  getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
327 
328  // Get a unique global name for the format.
329  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
330 
331  llvm::SmallString<20> formatString(adaptor.getFormat());
332  formatString.push_back('\0'); // Null terminate for C
333  auto globalType =
334  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
335  LLVM::GlobalOp global;
336  {
338  rewriter.setInsertionPointToStart(moduleOp.getBody());
339  global = rewriter.create<LLVM::GlobalOp>(
340  loc, globalType,
341  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
342  rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
343  }
344 
345  // Get a pointer to the format string's first element
346  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
347  Value stringStart = rewriter.create<LLVM::GEPOp>(
348  loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
349 
350  // Construct arguments and function call
351  auto argsRange = adaptor.getArgs();
352  SmallVector<Value, 4> printfArgs;
353  printfArgs.reserve(argsRange.size() + 1);
354  printfArgs.push_back(stringStart);
355  printfArgs.append(argsRange.begin(), argsRange.end());
356 
357  rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
358  rewriter.eraseOp(gpuPrintfOp);
359  return success();
360 }
361 
363  gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
364  ConversionPatternRewriter &rewriter) const {
365  Location loc = gpuPrintfOp->getLoc();
366 
367  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
368  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
369 
370  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
371  // This ensures that global constants and declarations are placed within
372  // the device code, not the host code.
373  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
374 
375  auto vprintfType =
376  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
377  LLVM::LLVMFuncOp vprintfDecl =
378  getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
379 
380  // Get a unique global name for the format.
381  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
382 
383  llvm::SmallString<20> formatString(adaptor.getFormat());
384  formatString.push_back('\0'); // Null terminate for C
385  auto globalType =
386  LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
387  LLVM::GlobalOp global;
388  {
390  rewriter.setInsertionPointToStart(moduleOp.getBody());
391  global = rewriter.create<LLVM::GlobalOp>(
392  loc, globalType,
393  /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
394  rewriter.getStringAttr(formatString), /*allignment=*/0);
395  }
396 
397  // Get a pointer to the format string's first element
398  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
399  Value stringStart = rewriter.create<LLVM::GEPOp>(
400  loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
401  SmallVector<Type> types;
402  SmallVector<Value> args;
403  // Promote and pack the arguments into a stack allocation.
404  for (Value arg : adaptor.getArgs()) {
405  Type type = arg.getType();
406  Value promotedArg = arg;
407  assert(type.isIntOrFloat());
408  if (type.isa<FloatType>()) {
409  type = rewriter.getF64Type();
410  promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
411  }
412  types.push_back(type);
413  args.push_back(promotedArg);
414  }
415  Type structType =
416  LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
417  Type structPtrType = LLVM::LLVMPointerType::get(structType);
418  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
419  rewriter.getIndexAttr(1));
420  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
421  /*alignment=*/0);
422  for (auto [index, arg] : llvm::enumerate(args)) {
423  Value ptr = rewriter.create<LLVM::GEPOp>(
424  loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
425  ArrayRef<LLVM::GEPArg>{0, index});
426  rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
427  }
428  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
429  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
430 
431  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
432  rewriter.eraseOp(gpuPrintfOp);
433  return success();
434 }
435 
436 /// Unrolls op if it's operating on vectors.
438  ConversionPatternRewriter &rewriter,
439  LLVMTypeConverter &converter) {
440  TypeRange operandTypes(operands);
441  if (llvm::none_of(operandTypes,
442  [](Type type) { return type.isa<VectorType>(); })) {
443  return rewriter.notifyMatchFailure(op, "expected vector operand");
444  }
445  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
446  return rewriter.notifyMatchFailure(op, "expected no region/successor");
447  if (op->getNumResults() != 1)
448  return rewriter.notifyMatchFailure(op, "expected single result");
449  VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
450  if (!vectorType)
451  return rewriter.notifyMatchFailure(op, "expected vector result");
452 
453  Location loc = op->getLoc();
454  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
455  Type indexType = converter.convertType(rewriter.getIndexType());
456  StringAttr name = op->getName().getIdentifier();
457  Type elementType = vectorType.getElementType();
458 
459  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
460  Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
461  auto extractElement = [&](Value operand) -> Value {
462  if (!operand.getType().isa<VectorType>())
463  return operand;
464  return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
465  };
466  auto scalarOperands =
467  llvm::to_vector(llvm::map_range(operands, extractElement));
468  Operation *scalarOp =
469  rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
470  rewriter.create<LLVM::InsertElementOp>(loc, result, scalarOp->getResult(0),
471  index);
472  }
473 
474  rewriter.replaceOp(op, result);
475  return success();
476 }
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
static SmallString< 16 > getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:304
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
UnitAttr getUnitAttr()
Definition: Builders.cpp:99
IntegerType getI64Type()
Definition: Builders.cpp:70
IntegerType getI32Type()
Definition: Builders.cpp:68
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:56
IntegerType getI8Type()
Definition: Builders.cpp:64
FloatType getF64Type()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
TypeConverter * typeConverter
An optional type converter for use by this pattern.
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result)
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:447
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:301
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:384
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
unsigned getNumSuccessors()
Definition: Operation.h:552
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:368
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:519
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:198
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:400
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:94
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:365
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:58
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast() const
Definition: Types.h:311
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:105
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:109
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &converter)
Unrolls op if it's operating on vectors.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26