MLIR  22.0.0git
FunctionCallUtils.cpp
Go to the documentation of this file.
1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
10 // LLVM IR (e.g., among others, to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF16 = "printF16";
30 static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31 static constexpr llvm::StringRef kPrintF32 = "printF32";
32 static constexpr llvm::StringRef kPrintF64 = "printF64";
33 static constexpr llvm::StringRef kPrintString = "printString";
34 static constexpr llvm::StringRef kPrintOpen = "printOpen";
35 static constexpr llvm::StringRef kPrintClose = "printClose";
36 static constexpr llvm::StringRef kPrintComma = "printComma";
37 static constexpr llvm::StringRef kPrintNewline = "printNewline";
38 static constexpr llvm::StringRef kMalloc = "malloc";
39 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40 static constexpr llvm::StringRef kFree = "free";
41 static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42 static constexpr llvm::StringRef kGenericAlignedAlloc =
43  "_mlir_memref_to_llvm_aligned_alloc";
44 static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46 
47 namespace {
48 /// Search for an LLVMFuncOp with a given name within an operation with the
49 /// SymbolTable trait. An optional collection of cached symbol tables can be
50 /// given to avoid a linear scan of the symbol table operation.
51 LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
52  SymbolTableCollection *symbolTables = nullptr) {
53  if (symbolTables) {
54  return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
55  symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
56  }
57 
58  return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
59  SymbolTable::lookupSymbolIn(symbolTableOp, name));
60 }
61 } // namespace
62 
63 /// Generic print function lookupOrCreate helper.
64 FailureOr<LLVM::LLVMFuncOp>
65 mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
66  ArrayRef<Type> paramTypes, Type resultType,
67  bool isVarArg, bool isReserved,
68  SymbolTableCollection *symbolTables) {
69  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
70  "expected SymbolTable operation");
71  auto func = lookupFuncOp(name, moduleOp, symbolTables);
72  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
73  // Assert the signature of the found function is same as expected
74  if (func) {
75  if (funcT != func.getFunctionType()) {
76  if (isReserved) {
77  func.emitError("redefinition of reserved function '")
78  << name << "' of different type " << func.getFunctionType()
79  << " is prohibited";
80  } else {
81  func.emitError("redefinition of function '")
82  << name << "' of different type " << funcT << " is prohibited";
83  }
84  return failure();
85  }
86  return func;
87  }
88 
90  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
91  b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
92  auto funcOp = LLVM::LLVMFuncOp::create(
93  b, moduleOp->getLoc(), name,
94  LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
95 
96  if (symbolTables) {
97  SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
98  symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
99  }
100 
101  return funcOp;
102 }
103 
104 static FailureOr<LLVM::LLVMFuncOp>
105 lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
106  ArrayRef<Type> paramTypes, Type resultType,
107  SymbolTableCollection *symbolTables) {
108  return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
109  /*isVarArg=*/false, /*isReserved=*/true,
110  symbolTables);
111 }
112 
113 FailureOr<LLVM::LLVMFuncOp>
115  SymbolTableCollection *symbolTables) {
117  b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
118  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
119 }
120 
121 FailureOr<LLVM::LLVMFuncOp>
123  SymbolTableCollection *symbolTables) {
125  b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
126  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
127 }
128 
129 FailureOr<LLVM::LLVMFuncOp>
131  SymbolTableCollection *symbolTables) {
133  b, moduleOp, kPrintF16,
134  IntegerType::get(moduleOp->getContext(), 16), // bits!
135  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
136 }
137 
138 FailureOr<LLVM::LLVMFuncOp>
140  SymbolTableCollection *symbolTables) {
142  b, moduleOp, kPrintBF16,
143  IntegerType::get(moduleOp->getContext(), 16), // bits!
144  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
145 }
146 
147 FailureOr<LLVM::LLVMFuncOp>
149  SymbolTableCollection *symbolTables) {
151  b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
152  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
153 }
154 
155 FailureOr<LLVM::LLVMFuncOp>
157  SymbolTableCollection *symbolTables) {
159  b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
160  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
161 }
162 
163 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
164  return LLVM::LLVMPointerType::get(context);
165 }
166 
167 static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
168  // A char pointer and void ptr are the same in LLVM IR.
169  return getCharPtr(context);
170 }
171 
172 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
173  OpBuilder &b, Operation *moduleOp,
174  std::optional<StringRef> runtimeFunctionName,
175  SymbolTableCollection *symbolTables) {
177  b, moduleOp, runtimeFunctionName.value_or(kPrintString),
178  getCharPtr(moduleOp->getContext()),
179  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
180 }
181 
182 FailureOr<LLVM::LLVMFuncOp>
184  SymbolTableCollection *symbolTables) {
186  b, moduleOp, kPrintOpen, {},
187  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
188 }
189 
190 FailureOr<LLVM::LLVMFuncOp>
192  SymbolTableCollection *symbolTables) {
194  b, moduleOp, kPrintClose, {},
195  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
196 }
197 
198 FailureOr<LLVM::LLVMFuncOp>
200  SymbolTableCollection *symbolTables) {
202  b, moduleOp, kPrintComma, {},
203  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
204 }
205 
206 FailureOr<LLVM::LLVMFuncOp>
208  SymbolTableCollection *symbolTables) {
210  b, moduleOp, kPrintNewline, {},
211  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
212 }
213 
214 FailureOr<LLVM::LLVMFuncOp>
216  Type indexType,
217  SymbolTableCollection *symbolTables) {
218  return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
219  getVoidPtr(moduleOp->getContext()),
220  symbolTables);
221 }
222 
223 FailureOr<LLVM::LLVMFuncOp>
225  Type indexType,
226  SymbolTableCollection *symbolTables) {
228  b, moduleOp, kAlignedAlloc, {indexType, indexType},
229  getVoidPtr(moduleOp->getContext()), symbolTables);
230 }
231 
232 FailureOr<LLVM::LLVMFuncOp>
234  SymbolTableCollection *symbolTables) {
236  b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
237  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
238 }
239 
240 FailureOr<LLVM::LLVMFuncOp>
242  Type indexType,
243  SymbolTableCollection *symbolTables) {
244  return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
245  getVoidPtr(moduleOp->getContext()),
246  symbolTables);
247 }
248 
250  OpBuilder &b, Operation *moduleOp, Type indexType,
251  SymbolTableCollection *symbolTables) {
253  b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
254  getVoidPtr(moduleOp->getContext()), symbolTables);
255 }
256 
257 FailureOr<LLVM::LLVMFuncOp>
259  SymbolTableCollection *symbolTables) {
261  b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
262  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
263 }
264 
265 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
266  OpBuilder &b, Operation *moduleOp, Type indexType,
267  Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
269  b, moduleOp, kMemRefCopy,
270  ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
271  LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
272 }
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintI64
Helper functions to look up or create the declaration for commonly used external C function calls.
static constexpr llvm::StringRef kFree
static FailureOr< LLVM::LLVMFuncOp > lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes, Type resultType, SymbolTableCollection *symbolTables)
static constexpr llvm::StringRef kPrintU64
static constexpr llvm::StringRef kPrintBF16
static constexpr llvm::StringRef kPrintString
static constexpr llvm::StringRef kGenericAlignedAlloc
static constexpr llvm::StringRef kAlignedAlloc
static constexpr llvm::StringRef kPrintClose
static constexpr llvm::StringRef kMalloc
static constexpr llvm::StringRef kMemRefCopy
static constexpr llvm::StringRef kPrintComma
static constexpr llvm::StringRef kGenericAlloc
static constexpr llvm::StringRef kPrintNewline
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintOpen
static constexpr llvm::StringRef kGenericFree
static constexpr llvm::StringRef kPrintF16
static constexpr llvm::StringRef kPrintF32
static constexpr llvm::StringRef kPrintF64
iterator begin()
Definition: Block.h:143
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Declares a function to print a C-string.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...