MLIR  22.0.0git
MemRefToEmitC.cpp
Go to the documentation of this file.
1 //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert memref ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeRange.h"
23 #include "mlir/IR/Value.h"
25 #include <cstdint>
26 #include <numeric>
27 
28 using namespace mlir;
29 
30 static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
31  return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
32  memRefType.getRank() != 0 &&
33  !llvm::is_contained(memRefType.getShape(), 0);
34 }
35 
36 namespace {
37 /// Implement the interface to convert MemRef to EmitC.
38 struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
40 
41  /// Hook for derived dialect interface to provide conversion patterns
42  /// and mark dialect legal for the conversion target.
43  void populateConvertToEmitCConversionPatterns(
44  ConversionTarget &target, TypeConverter &typeConverter,
45  RewritePatternSet &patterns) const final {
48  }
49 };
50 } // namespace
51 
53  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
54  dialect->addInterfaces<MemRefToEmitCDialectInterface>();
55  });
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Conversion Patterns
60 //===----------------------------------------------------------------------===//
61 
62 namespace {
63 struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
65 
66  LogicalResult
67  matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
68  ConversionPatternRewriter &rewriter) const override {
69 
70  if (!op.getType().hasStaticShape()) {
71  return rewriter.notifyMatchFailure(
72  op.getLoc(), "cannot transform alloca with dynamic shape");
73  }
74 
75  if (op.getAlignment().value_or(1) > 1) {
76  // TODO: Allow alignment if it is not more than the natural alignment
77  // of the C array.
78  return rewriter.notifyMatchFailure(
79  op.getLoc(), "cannot transform alloca with alignment requirement");
80  }
81 
82  auto resultTy = getTypeConverter()->convertType(op.getType());
83  if (!resultTy) {
84  return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
85  }
86  auto noInit = emitc::OpaqueAttr::get(getContext(), "");
87  rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
88  return success();
89  }
90 };
91 
92 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
93  Type resultTy;
94  if (opTy.getRank() == 0) {
95  resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
96  } else {
97  resultTy = typeConverter->convertType(opTy);
98  }
99  return resultTy;
100 }
101 
102 static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
103  OpBuilder &builder) {
104  assert(isMemRefTypeLegalForEmitC(memrefType) &&
105  "incompatible memref type for EmitC conversion");
106  emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
107  builder, loc, emitc::SizeTType::get(builder.getContext()),
108  builder.getStringAttr("sizeof"), ValueRange{},
109  ArrayAttr::get(builder.getContext(),
110  {TypeAttr::get(memrefType.getElementType())}));
111 
112  IndexType indexType = builder.getIndexType();
113  int64_t numElements = std::accumulate(memrefType.getShape().begin(),
114  memrefType.getShape().end(), int64_t{1},
115  std::multiplies<int64_t>());
116  emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
117  builder, loc, indexType, builder.getIndexAttr(numElements));
118 
119  Type sizeTType = emitc::SizeTType::get(builder.getContext());
120  emitc::MulOp totalSizeBytes = emitc::MulOp::create(
121  builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
122 
123  return totalSizeBytes.getResult();
124 }
125 
126 static emitc::ApplyOp
127 createPointerFromEmitcArray(Location loc, OpBuilder &builder,
128  TypedValue<emitc::ArrayType> arrayValue) {
129 
130  emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
131  builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
132 
133  emitc::ArrayType arrayType = arrayValue.getType();
134  llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
135  emitc::SubscriptOp subPtr =
136  emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
137  emitc::ApplyOp ptr = emitc::ApplyOp::create(
138  builder, loc, emitc::PointerType::get(arrayType.getElementType()),
139  builder.getStringAttr("&"), subPtr);
140 
141  return ptr;
142 }
143 
144 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
146  LogicalResult
147  matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
148  ConversionPatternRewriter &rewriter) const override {
149  Location loc = allocOp.getLoc();
150  MemRefType memrefType = allocOp.getType();
151  if (!isMemRefTypeLegalForEmitC(memrefType)) {
152  return rewriter.notifyMatchFailure(
153  loc, "incompatible memref type for EmitC conversion");
154  }
155 
156  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
157  Type elementType = memrefType.getElementType();
158  IndexType indexType = rewriter.getIndexType();
159  emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
160  rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
161  ValueRange{},
162  ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
163 
164  int64_t numElements = 1;
165  for (int64_t dimSize : memrefType.getShape()) {
166  numElements *= dimSize;
167  }
168  Value numElementsValue = emitc::ConstantOp::create(
169  rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
170 
171  Value totalSizeBytes =
172  emitc::MulOp::create(rewriter, loc, sizeTType,
173  sizeofElementOp.getResult(0), numElementsValue);
174 
175  emitc::CallOpaqueOp allocCall;
176  StringAttr allocFunctionName;
177  Value alignmentValue;
178  SmallVector<Value, 2> argsVec;
179  if (allocOp.getAlignment()) {
180  allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
181  alignmentValue = emitc::ConstantOp::create(
182  rewriter, loc, sizeTType,
183  rewriter.getIntegerAttr(indexType,
184  allocOp.getAlignment().value_or(0)));
185  argsVec.push_back(alignmentValue);
186  } else {
187  allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
188  }
189 
190  argsVec.push_back(totalSizeBytes);
191  ValueRange args(argsVec);
192 
193  allocCall = emitc::CallOpaqueOp::create(
194  rewriter, loc,
196  emitc::OpaqueType::get(rewriter.getContext(), "void")),
197  allocFunctionName, args);
198 
199  emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
200  emitc::CastOp castOp = emitc::CastOp::create(
201  rewriter, loc, targetPointerType, allocCall.getResult(0));
202 
203  rewriter.replaceOp(allocOp, castOp);
204  return success();
205  }
206 };
207 
208 struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
210 
211  LogicalResult
212  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
213  ConversionPatternRewriter &rewriter) const override {
214  Location loc = copyOp.getLoc();
215  MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
216  MemRefType targetMemrefType =
217  cast<MemRefType>(copyOp.getTarget().getType());
218 
219  if (!isMemRefTypeLegalForEmitC(srcMemrefType))
220  return rewriter.notifyMatchFailure(
221  loc, "incompatible source memref type for EmitC conversion");
222 
223  if (!isMemRefTypeLegalForEmitC(targetMemrefType))
224  return rewriter.notifyMatchFailure(
225  loc, "incompatible target memref type for EmitC conversion");
226 
227  auto srcArrayValue =
228  cast<TypedValue<emitc::ArrayType>>(operands.getSource());
229  emitc::ApplyOp srcPtr =
230  createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
231 
232  auto targetArrayValue =
233  cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
234  emitc::ApplyOp targetPtr =
235  createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
236 
237  emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
238  rewriter, loc, TypeRange{}, "memcpy",
239  ValueRange{
240  targetPtr.getResult(), srcPtr.getResult(),
241  calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
242 
243  rewriter.replaceOp(copyOp, memCpyCall.getResults());
244 
245  return success();
246  }
247 };
248 
249 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
251 
252  LogicalResult
253  matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
254  ConversionPatternRewriter &rewriter) const override {
255  MemRefType opTy = op.getType();
256  if (!op.getType().hasStaticShape()) {
257  return rewriter.notifyMatchFailure(
258  op.getLoc(), "cannot transform global with dynamic shape");
259  }
260 
261  if (op.getAlignment().value_or(1) > 1) {
262  // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
263  return rewriter.notifyMatchFailure(
264  op.getLoc(), "global variable with alignment requirement is "
265  "currently not supported");
266  }
267 
268  Type resultTy = convertMemRefType(opTy, getTypeConverter());
269 
270  if (!resultTy) {
271  return rewriter.notifyMatchFailure(op.getLoc(),
272  "cannot convert result type");
273  }
274 
276  if (visibility != SymbolTable::Visibility::Public &&
277  visibility != SymbolTable::Visibility::Private) {
278  return rewriter.notifyMatchFailure(
279  op.getLoc(),
280  "only public and private visibility is currently supported");
281  }
282  // We are explicit in specifing the linkage because the default linkage
283  // for constants is different in C and C++.
284  bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
285  bool externSpecifier = !staticSpecifier;
286 
287  Attribute initialValue = operands.getInitialValueAttr();
288  if (opTy.getRank() == 0) {
289  auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
290  initialValue = elementsAttr.getSplatValue<Attribute>();
291  }
292  if (isa_and_present<UnitAttr>(initialValue))
293  initialValue = {};
294 
295  rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
296  op, operands.getSymName(), resultTy, initialValue, externSpecifier,
297  staticSpecifier, operands.getConstant());
298  return success();
299  }
300 };
301 
302 struct ConvertGetGlobal final
303  : public OpConversionPattern<memref::GetGlobalOp> {
305 
306  LogicalResult
307  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
308  ConversionPatternRewriter &rewriter) const override {
309 
310  MemRefType opTy = op.getType();
311  Type resultTy = convertMemRefType(opTy, getTypeConverter());
312 
313  if (!resultTy) {
314  return rewriter.notifyMatchFailure(op.getLoc(),
315  "cannot convert result type");
316  }
317 
318  if (opTy.getRank() == 0) {
319  emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
320  emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
321  rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
322  emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
323  rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
324  op, pointerType, rewriter.getStringAttr("&"), globalLValue);
325  return success();
326  }
327  rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
328  operands.getNameAttr());
329  return success();
330  }
331 };
332 
333 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
335 
336  LogicalResult
337  matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
338  ConversionPatternRewriter &rewriter) const override {
339 
340  auto resultTy = getTypeConverter()->convertType(op.getType());
341  if (!resultTy) {
342  return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
343  }
344 
345  auto arrayValue =
346  dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
347  if (!arrayValue) {
348  return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
349  }
350 
351  auto subscript = emitc::SubscriptOp::create(
352  rewriter, op.getLoc(), arrayValue, operands.getIndices());
353 
354  rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
355  return success();
356  }
357 };
358 
359 struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
361 
362  LogicalResult
363  matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
364  ConversionPatternRewriter &rewriter) const override {
365  auto arrayValue =
366  dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
367  if (!arrayValue) {
368  return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
369  }
370 
371  auto subscript = emitc::SubscriptOp::create(
372  rewriter, op.getLoc(), arrayValue, operands.getIndices());
373  rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
374  operands.getValue());
375  return success();
376  }
377 };
378 } // namespace
379 
381  typeConverter.addConversion(
382  [&](MemRefType memRefType) -> std::optional<Type> {
383  if (!isMemRefTypeLegalForEmitC(memRefType)) {
384  return {};
385  }
386  Type convertedElementType =
387  typeConverter.convertType(memRefType.getElementType());
388  if (!convertedElementType)
389  return {};
390  return emitc::ArrayType::get(memRefType.getShape(),
391  convertedElementType);
392  });
393 
394  auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
395  ValueRange inputs,
396  Location loc) -> Value {
397  if (inputs.size() != 1)
398  return Value();
399 
400  return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
401  .getResult(0);
402  };
403 
404  typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
405  typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
406 }
407 
409  RewritePatternSet &patterns, const TypeConverter &converter) {
410  patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
411  ConvertGetGlobal, ConvertLoad, ConvertStore>(
412  converter, patterns.getContext());
413 }
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
Definition: MemRefToEmitC.h:12
constexpr const char * alignedAllocFunctionName
Definition: MemRefToEmitC.h:11
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:90
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)