MLIR  21.0.0git
FunctionCallUtils.cpp
Go to the documentation of this file.
1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
10 // LLVMIR (e.g. among others to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF16 = "printF16";
30 static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31 static constexpr llvm::StringRef kPrintF32 = "printF32";
32 static constexpr llvm::StringRef kPrintF64 = "printF64";
33 static constexpr llvm::StringRef kPrintString = "printString";
34 static constexpr llvm::StringRef kPrintOpen = "printOpen";
35 static constexpr llvm::StringRef kPrintClose = "printClose";
36 static constexpr llvm::StringRef kPrintComma = "printComma";
37 static constexpr llvm::StringRef kPrintNewline = "printNewline";
38 static constexpr llvm::StringRef kMalloc = "malloc";
39 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40 static constexpr llvm::StringRef kFree = "free";
41 static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42 static constexpr llvm::StringRef kGenericAlignedAlloc =
43  "_mlir_memref_to_llvm_aligned_alloc";
44 static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46 
47 /// Generic print function lookupOrCreate helper.
48 FailureOr<LLVM::LLVMFuncOp>
49 mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
50  ArrayRef<Type> paramTypes, Type resultType,
51  bool isVarArg, bool isReserved) {
52  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
53  "expected SymbolTable operation");
54  auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
55  SymbolTable::lookupSymbolIn(moduleOp, name));
56  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
57  // Assert the signature of the found function is same as expected
58  if (func) {
59  if (funcT != func.getFunctionType()) {
60  if (isReserved) {
61  func.emitError("redefinition of reserved function '")
62  << name << "' of different type " << func.getFunctionType()
63  << " is prohibited";
64  } else {
65  func.emitError("redefinition of function '")
66  << name << "' of different type " << funcT << " is prohibited";
67  }
68  return failure();
69  }
70  return func;
71  }
72 
74  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
75  b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
76  return b.create<LLVM::LLVMFuncOp>(
77  moduleOp->getLoc(), name,
78  LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
79 }
80 
81 static FailureOr<LLVM::LLVMFuncOp>
82 lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
83  ArrayRef<Type> paramTypes, Type resultType) {
84  return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
85  /*isVarArg=*/false, /*isReserved=*/true);
86 }
87 
88 FailureOr<LLVM::LLVMFuncOp>
91  b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
92  LLVM::LLVMVoidType::get(moduleOp->getContext()));
93 }
94 
95 FailureOr<LLVM::LLVMFuncOp>
98  b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
99  LLVM::LLVMVoidType::get(moduleOp->getContext()));
100 }
101 
102 FailureOr<LLVM::LLVMFuncOp>
105  b, moduleOp, kPrintF16,
106  IntegerType::get(moduleOp->getContext(), 16), // bits!
107  LLVM::LLVMVoidType::get(moduleOp->getContext()));
108 }
109 
110 FailureOr<LLVM::LLVMFuncOp>
113  b, moduleOp, kPrintBF16,
114  IntegerType::get(moduleOp->getContext(), 16), // bits!
115  LLVM::LLVMVoidType::get(moduleOp->getContext()));
116 }
117 
118 FailureOr<LLVM::LLVMFuncOp>
121  b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
122  LLVM::LLVMVoidType::get(moduleOp->getContext()));
123 }
124 
125 FailureOr<LLVM::LLVMFuncOp>
128  b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
129  LLVM::LLVMVoidType::get(moduleOp->getContext()));
130 }
131 
132 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
133  return LLVM::LLVMPointerType::get(context);
134 }
135 
136 static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
137  // A char pointer and void ptr are the same in LLVM IR.
138  return getCharPtr(context);
139 }
140 
141 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
142  OpBuilder &b, Operation *moduleOp,
143  std::optional<StringRef> runtimeFunctionName) {
145  b, moduleOp, runtimeFunctionName.value_or(kPrintString),
146  getCharPtr(moduleOp->getContext()),
147  LLVM::LLVMVoidType::get(moduleOp->getContext()));
148 }
149 
150 FailureOr<LLVM::LLVMFuncOp>
153  b, moduleOp, kPrintOpen, {},
154  LLVM::LLVMVoidType::get(moduleOp->getContext()));
155 }
156 
157 FailureOr<LLVM::LLVMFuncOp>
160  b, moduleOp, kPrintClose, {},
161  LLVM::LLVMVoidType::get(moduleOp->getContext()));
162 }
163 
164 FailureOr<LLVM::LLVMFuncOp>
167  b, moduleOp, kPrintComma, {},
168  LLVM::LLVMVoidType::get(moduleOp->getContext()));
169 }
170 
171 FailureOr<LLVM::LLVMFuncOp>
174  b, moduleOp, kPrintNewline, {},
175  LLVM::LLVMVoidType::get(moduleOp->getContext()));
176 }
177 
178 FailureOr<LLVM::LLVMFuncOp>
180  Type indexType) {
181  return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
182  getVoidPtr(moduleOp->getContext()));
183 }
184 
185 FailureOr<LLVM::LLVMFuncOp>
187  Type indexType) {
188  return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
189  {indexType, indexType},
190  getVoidPtr(moduleOp->getContext()));
191 }
192 
193 FailureOr<LLVM::LLVMFuncOp>
196  b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
197  LLVM::LLVMVoidType::get(moduleOp->getContext()));
198 }
199 
200 FailureOr<LLVM::LLVMFuncOp>
202  Type indexType) {
203  return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
204  getVoidPtr(moduleOp->getContext()));
205 }
206 
208  OpBuilder &b, Operation *moduleOp, Type indexType) {
210  {indexType, indexType},
211  getVoidPtr(moduleOp->getContext()));
212 }
213 
214 FailureOr<LLVM::LLVMFuncOp>
217  b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
218  LLVM::LLVMVoidType::get(moduleOp->getContext()));
219 }
220 
221 FailureOr<LLVM::LLVMFuncOp>
223  Type indexType,
224  Type unrankedDescriptorType) {
226  b, moduleOp, kMemRefCopy,
227  ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
228  LLVM::LLVMVoidType::get(moduleOp->getContext()));
229 }
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintI64
Helper functions to lookup or create the declaration for commonly used external C function calls.
static constexpr llvm::StringRef kFree
static constexpr llvm::StringRef kPrintU64
static constexpr llvm::StringRef kPrintBF16
static constexpr llvm::StringRef kPrintString
static constexpr llvm::StringRef kGenericAlignedAlloc
static constexpr llvm::StringRef kAlignedAlloc
static constexpr llvm::StringRef kPrintClose
static constexpr llvm::StringRef kMalloc
static constexpr llvm::StringRef kMemRefCopy
static constexpr llvm::StringRef kPrintComma
static constexpr llvm::StringRef kGenericAlloc
static constexpr llvm::StringRef kPrintNewline
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintOpen
static constexpr llvm::StringRef kGenericFree
static FailureOr< LLVM::LLVMFuncOp > lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes, Type resultType)
static constexpr llvm::StringRef kPrintF16
static constexpr llvm::StringRef kPrintF32
static constexpr llvm::StringRef kPrintF64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, std::optional< StringRef > runtimeFunctionName={})
Declares a function to print a C-string.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false)
Create a FuncOp with signature `resultType`(`paramTypes`) and name `name`.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...