MLIR 23.0.0git
FunctionCallUtils.cpp
Go to the documentation of this file.
1//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper functions to call common simple C functions in
// LLVM IR (e.g., to support printing and debugging, among other uses).
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/IR/Builders.h"
18#include "mlir/Support/LLVM.h"
19
20using namespace mlir;
21using namespace mlir::LLVM;
22
23/// Helper functions to lookup or create the declaration for commonly used
24/// external C function calls. The list of functions provided here must be
25/// implemented separately (e.g. as part of a support runtime library or as
26/// part of the libc).
27static constexpr llvm::StringRef kPrintI64 = "printI64";
28static constexpr llvm::StringRef kPrintU64 = "printU64";
29static constexpr llvm::StringRef kPrintF16 = "printF16";
30static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31static constexpr llvm::StringRef kPrintF32 = "printF32";
32static constexpr llvm::StringRef kPrintF64 = "printF64";
33static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
34static constexpr llvm::StringRef kPrintString = "printString";
35static constexpr llvm::StringRef kPrintOpen = "printOpen";
36static constexpr llvm::StringRef kPrintClose = "printClose";
37static constexpr llvm::StringRef kPrintComma = "printComma";
38static constexpr llvm::StringRef kPrintNewline = "printNewline";
39static constexpr llvm::StringRef kMalloc = "malloc";
40static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
41static constexpr llvm::StringRef kFree = "free";
42static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
43static constexpr llvm::StringRef kGenericAlignedAlloc =
44 "_mlir_memref_to_llvm_aligned_alloc";
45static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
46static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
47
48namespace {
49/// Search for an LLVMFuncOp with a given name within an operation with the
50/// SymbolTable trait. An optional collection of cached symbol tables can be
51/// given to avoid a linear scan of the symbol table operation.
52LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
53 SymbolTableCollection *symbolTables = nullptr) {
54 if (symbolTables) {
55 return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
56 symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
57 }
58
59 return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
60 SymbolTable::lookupSymbolIn(symbolTableOp, name));
61}
62} // namespace
63
64/// Generic print function lookupOrCreate helper.
65FailureOr<LLVM::LLVMFuncOp>
67 ArrayRef<Type> paramTypes, Type resultType,
68 bool isVarArg, bool isReserved,
69 SymbolTableCollection *symbolTables) {
70 assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
71 "expected SymbolTable operation");
72 auto func = lookupFuncOp(name, moduleOp, symbolTables);
73 auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
74 // Assert the signature of the found function is same as expected
75 if (func) {
76 if (funcT != func.getFunctionType()) {
77 if (isReserved) {
78 func.emitError("redefinition of reserved function '")
79 << name << "' of different type " << func.getFunctionType()
80 << " is prohibited";
81 } else {
82 func.emitError("redefinition of function '")
83 << name << "' of different type " << funcT << " is prohibited";
84 }
85 return failure();
86 }
87 return func;
88 }
89
90 // A symbol with this name may already exist as a non-LLVM function (e.g.,
91 // func::FuncOp from user code that hasn't been converted to LLVM dialect
92 // yet). Creating a new LLVMFuncOp with the same name would cause a symbol
93 // redefinition error. Return failure so the calling pattern can retry after
94 // the existing symbol is converted.
95 if (symbolTables
96 ? symbolTables->lookupSymbolIn(
97 moduleOp, StringAttr::get(moduleOp->getContext(), name))
98 : SymbolTable::lookupSymbolIn(moduleOp, name))
99 return failure();
100
102 assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
103 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
104 auto funcOp = LLVM::LLVMFuncOp::create(
105 b, moduleOp->getLoc(), name,
106 LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
107
108 if (symbolTables) {
109 SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
110 symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
111 }
112
113 return funcOp;
114}
115
116static FailureOr<LLVM::LLVMFuncOp>
117lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
118 ArrayRef<Type> paramTypes, Type resultType,
119 SymbolTableCollection *symbolTables) {
120 return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
121 /*isVarArg=*/false, /*isReserved=*/true,
122 symbolTables);
123}
124
125FailureOr<LLVM::LLVMFuncOp>
127 SymbolTableCollection *symbolTables) {
129 b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
130 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
131}
132
133FailureOr<LLVM::LLVMFuncOp>
135 SymbolTableCollection *symbolTables) {
137 b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
138 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
139}
140
141FailureOr<LLVM::LLVMFuncOp>
143 SymbolTableCollection *symbolTables) {
145 b, moduleOp, kPrintF16,
146 IntegerType::get(moduleOp->getContext(), 16), // bits!
147 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
148}
149
150FailureOr<LLVM::LLVMFuncOp>
152 SymbolTableCollection *symbolTables) {
154 b, moduleOp, kPrintBF16,
155 IntegerType::get(moduleOp->getContext(), 16), // bits!
156 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
157}
158
159FailureOr<LLVM::LLVMFuncOp>
161 SymbolTableCollection *symbolTables) {
163 b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
164 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
165}
166
167FailureOr<LLVM::LLVMFuncOp>
169 SymbolTableCollection *symbolTables) {
171 b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
172 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
173}
174
175FailureOr<LLVM::LLVMFuncOp>
177 SymbolTableCollection *symbolTables) {
179 b, moduleOp, kPrintApFloat,
180 {IntegerType::get(moduleOp->getContext(), 32),
181 IntegerType::get(moduleOp->getContext(), 64)},
182 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
183}
184
185static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
186 return LLVM::LLVMPointerType::get(context);
187}
188
189static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
190 // A char pointer and void ptr are the same in LLVM IR.
191 return getCharPtr(context);
192}
193
194FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
195 OpBuilder &b, Operation *moduleOp,
196 std::optional<StringRef> runtimeFunctionName,
197 SymbolTableCollection *symbolTables) {
199 b, moduleOp, runtimeFunctionName.value_or(kPrintString),
200 getCharPtr(moduleOp->getContext()),
201 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
202}
203
204FailureOr<LLVM::LLVMFuncOp>
206 SymbolTableCollection *symbolTables) {
208 b, moduleOp, kPrintOpen, {},
209 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
210}
211
212FailureOr<LLVM::LLVMFuncOp>
214 SymbolTableCollection *symbolTables) {
216 b, moduleOp, kPrintClose, {},
217 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
218}
219
220FailureOr<LLVM::LLVMFuncOp>
222 SymbolTableCollection *symbolTables) {
224 b, moduleOp, kPrintComma, {},
225 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
226}
227
228FailureOr<LLVM::LLVMFuncOp>
230 SymbolTableCollection *symbolTables) {
232 b, moduleOp, kPrintNewline, {},
233 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
234}
235
236FailureOr<LLVM::LLVMFuncOp>
238 Type indexType,
239 SymbolTableCollection *symbolTables) {
240 return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
241 getVoidPtr(moduleOp->getContext()),
242 symbolTables);
243}
244
245FailureOr<LLVM::LLVMFuncOp>
247 Type indexType,
248 SymbolTableCollection *symbolTables) {
250 b, moduleOp, kAlignedAlloc, {indexType, indexType},
251 getVoidPtr(moduleOp->getContext()), symbolTables);
252}
253
254FailureOr<LLVM::LLVMFuncOp>
256 SymbolTableCollection *symbolTables) {
258 b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
259 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
260}
261
262FailureOr<LLVM::LLVMFuncOp>
264 Type indexType,
265 SymbolTableCollection *symbolTables) {
266 return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
267 getVoidPtr(moduleOp->getContext()),
268 symbolTables);
269}
270
272 OpBuilder &b, Operation *moduleOp, Type indexType,
273 SymbolTableCollection *symbolTables) {
275 b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
276 getVoidPtr(moduleOp->getContext()), symbolTables);
277}
278
279FailureOr<LLVM::LLVMFuncOp>
281 SymbolTableCollection *symbolTables) {
283 b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
284 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
285}
286
287FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
288 OpBuilder &b, Operation *moduleOp, Type indexType,
289 Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
291 b, moduleOp, kMemRefCopy,
292 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
293 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
294}
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintI64
Helper functions to look up or create the declaration for commonly used external C function calls.
static constexpr llvm::StringRef kPrintApFloat
static FailureOr< LLVM::LLVMFuncOp > lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes, Type resultType, SymbolTableCollection *symbolTables)
static constexpr llvm::StringRef kFree
static constexpr llvm::StringRef kPrintU64
static constexpr llvm::StringRef kPrintBF16
static constexpr llvm::StringRef kPrintString
static constexpr llvm::StringRef kGenericAlignedAlloc
static constexpr llvm::StringRef kAlignedAlloc
static constexpr llvm::StringRef kPrintClose
static constexpr llvm::StringRef kMalloc
static constexpr llvm::StringRef kMemRefCopy
static constexpr llvm::StringRef kPrintComma
static constexpr llvm::StringRef kGenericAlloc
static constexpr llvm::StringRef kPrintNewline
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintOpen
static constexpr llvm::StringRef kGenericFree
static constexpr llvm::StringRef kPrintF16
static constexpr llvm::StringRef kPrintF32
static constexpr llvm::StringRef kPrintF64
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
iterator begin()
Definition Block.h:153
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
This class represents a collection of SymbolTables.
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Declares a function to print a C-string.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name `name`.
Include the generated interface declarations.