MLIR  22.0.0git
MemRefToEmitC.cpp
Go to the documentation of this file.
1 //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert memref ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeRange.h"
23 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include <cstdint>
27 #include <numeric>
28 
29 using namespace mlir;
30 
31 static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
32  return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33  memRefType.getRank() != 0 &&
34  !llvm::is_contained(memRefType.getShape(), 0);
35 }
36 
37 namespace {
38 /// Implement the interface to convert MemRef to EmitC.
39 struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
41 
42  /// Hook for derived dialect interface to provide conversion patterns
43  /// and mark dialect legal for the conversion target.
44  void populateConvertToEmitCConversionPatterns(
45  ConversionTarget &target, TypeConverter &typeConverter,
46  RewritePatternSet &patterns) const final {
49  }
50 };
51 } // namespace
52 
54  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
55  dialect->addInterfaces<MemRefToEmitCDialectInterface>();
56  });
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // Conversion Patterns
61 //===----------------------------------------------------------------------===//
62 
63 namespace {
64 struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
66 
67  LogicalResult
68  matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69  ConversionPatternRewriter &rewriter) const override {
70 
71  if (!op.getType().hasStaticShape()) {
72  return rewriter.notifyMatchFailure(
73  op.getLoc(), "cannot transform alloca with dynamic shape");
74  }
75 
76  if (op.getAlignment().value_or(1) > 1) {
77  // TODO: Allow alignment if it is not more than the natural alignment
78  // of the C array.
79  return rewriter.notifyMatchFailure(
80  op.getLoc(), "cannot transform alloca with alignment requirement");
81  }
82 
83  auto resultTy = getTypeConverter()->convertType(op.getType());
84  if (!resultTy) {
85  return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
86  }
87  auto noInit = emitc::OpaqueAttr::get(getContext(), "");
88  rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
89  return success();
90  }
91 };
92 
93 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
94  Type resultTy;
95  if (opTy.getRank() == 0) {
96  resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
97  } else {
98  resultTy = typeConverter->convertType(opTy);
99  }
100  return resultTy;
101 }
102 
103 static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
104  OpBuilder &builder) {
105  assert(isMemRefTypeLegalForEmitC(memrefType) &&
106  "incompatible memref type for EmitC conversion");
107  emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
108  builder, loc, emitc::SizeTType::get(builder.getContext()),
109  builder.getStringAttr("sizeof"), ValueRange{},
110  ArrayAttr::get(builder.getContext(),
111  {TypeAttr::get(memrefType.getElementType())}));
112 
113  IndexType indexType = builder.getIndexType();
114  int64_t numElements = llvm::product_of(memrefType.getShape());
115  emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
116  builder, loc, indexType, builder.getIndexAttr(numElements));
117 
118  Type sizeTType = emitc::SizeTType::get(builder.getContext());
119  emitc::MulOp totalSizeBytes = emitc::MulOp::create(
120  builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
121 
122  return totalSizeBytes.getResult();
123 }
124 
125 static emitc::ApplyOp
126 createPointerFromEmitcArray(Location loc, OpBuilder &builder,
127  TypedValue<emitc::ArrayType> arrayValue) {
128 
129  emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
130  builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
131 
132  emitc::ArrayType arrayType = arrayValue.getType();
133  llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
134  emitc::SubscriptOp subPtr =
135  emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
136  emitc::ApplyOp ptr = emitc::ApplyOp::create(
137  builder, loc, emitc::PointerType::get(arrayType.getElementType()),
138  builder.getStringAttr("&"), subPtr);
139 
140  return ptr;
141 }
142 
143 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
145  LogicalResult
146  matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
147  ConversionPatternRewriter &rewriter) const override {
148  Location loc = allocOp.getLoc();
149  MemRefType memrefType = allocOp.getType();
150  if (!isMemRefTypeLegalForEmitC(memrefType)) {
151  return rewriter.notifyMatchFailure(
152  loc, "incompatible memref type for EmitC conversion");
153  }
154 
155  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
156  Type elementType = memrefType.getElementType();
157  IndexType indexType = rewriter.getIndexType();
158  emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
159  rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
160  ValueRange{},
161  ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
162 
163  int64_t numElements = 1;
164  for (int64_t dimSize : memrefType.getShape()) {
165  numElements *= dimSize;
166  }
167  Value numElementsValue = emitc::ConstantOp::create(
168  rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
169 
170  Value totalSizeBytes =
171  emitc::MulOp::create(rewriter, loc, sizeTType,
172  sizeofElementOp.getResult(0), numElementsValue);
173 
174  emitc::CallOpaqueOp allocCall;
175  StringAttr allocFunctionName;
176  Value alignmentValue;
177  SmallVector<Value, 2> argsVec;
178  if (allocOp.getAlignment()) {
179  allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
180  alignmentValue = emitc::ConstantOp::create(
181  rewriter, loc, sizeTType,
182  rewriter.getIntegerAttr(indexType,
183  allocOp.getAlignment().value_or(0)));
184  argsVec.push_back(alignmentValue);
185  } else {
186  allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
187  }
188 
189  argsVec.push_back(totalSizeBytes);
190  ValueRange args(argsVec);
191 
192  allocCall = emitc::CallOpaqueOp::create(
193  rewriter, loc,
195  emitc::OpaqueType::get(rewriter.getContext(), "void")),
196  allocFunctionName, args);
197 
198  emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
199  emitc::CastOp castOp = emitc::CastOp::create(
200  rewriter, loc, targetPointerType, allocCall.getResult(0));
201 
202  rewriter.replaceOp(allocOp, castOp);
203  return success();
204  }
205 };
206 
207 struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
209 
210  LogicalResult
211  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
212  ConversionPatternRewriter &rewriter) const override {
213  Location loc = copyOp.getLoc();
214  MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
215  MemRefType targetMemrefType =
216  cast<MemRefType>(copyOp.getTarget().getType());
217 
218  if (!isMemRefTypeLegalForEmitC(srcMemrefType))
219  return rewriter.notifyMatchFailure(
220  loc, "incompatible source memref type for EmitC conversion");
221 
222  if (!isMemRefTypeLegalForEmitC(targetMemrefType))
223  return rewriter.notifyMatchFailure(
224  loc, "incompatible target memref type for EmitC conversion");
225 
226  auto srcArrayValue =
227  cast<TypedValue<emitc::ArrayType>>(operands.getSource());
228  emitc::ApplyOp srcPtr =
229  createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
230 
231  auto targetArrayValue =
232  cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
233  emitc::ApplyOp targetPtr =
234  createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
235 
236  emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
237  rewriter, loc, TypeRange{}, "memcpy",
238  ValueRange{
239  targetPtr.getResult(), srcPtr.getResult(),
240  calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
241 
242  rewriter.replaceOp(copyOp, memCpyCall.getResults());
243 
244  return success();
245  }
246 };
247 
248 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
250 
251  LogicalResult
252  matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
253  ConversionPatternRewriter &rewriter) const override {
254  MemRefType opTy = op.getType();
255  if (!op.getType().hasStaticShape()) {
256  return rewriter.notifyMatchFailure(
257  op.getLoc(), "cannot transform global with dynamic shape");
258  }
259 
260  if (op.getAlignment().value_or(1) > 1) {
261  // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
262  return rewriter.notifyMatchFailure(
263  op.getLoc(), "global variable with alignment requirement is "
264  "currently not supported");
265  }
266 
267  Type resultTy = convertMemRefType(opTy, getTypeConverter());
268 
269  if (!resultTy) {
270  return rewriter.notifyMatchFailure(op.getLoc(),
271  "cannot convert result type");
272  }
273 
275  if (visibility != SymbolTable::Visibility::Public &&
276  visibility != SymbolTable::Visibility::Private) {
277  return rewriter.notifyMatchFailure(
278  op.getLoc(),
279  "only public and private visibility is currently supported");
280  }
281  // We are explicit in specifing the linkage because the default linkage
282  // for constants is different in C and C++.
283  bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
284  bool externSpecifier = !staticSpecifier;
285 
286  Attribute initialValue = operands.getInitialValueAttr();
287  if (opTy.getRank() == 0) {
288  auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
289  initialValue = elementsAttr.getSplatValue<Attribute>();
290  }
291  if (isa_and_present<UnitAttr>(initialValue))
292  initialValue = {};
293 
294  rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
295  op, operands.getSymName(), resultTy, initialValue, externSpecifier,
296  staticSpecifier, operands.getConstant());
297  return success();
298  }
299 };
300 
301 struct ConvertGetGlobal final
302  : public OpConversionPattern<memref::GetGlobalOp> {
304 
305  LogicalResult
306  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
307  ConversionPatternRewriter &rewriter) const override {
308 
309  MemRefType opTy = op.getType();
310  Type resultTy = convertMemRefType(opTy, getTypeConverter());
311 
312  if (!resultTy) {
313  return rewriter.notifyMatchFailure(op.getLoc(),
314  "cannot convert result type");
315  }
316 
317  if (opTy.getRank() == 0) {
318  emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
319  emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
320  rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
321  emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
322  rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
323  op, pointerType, rewriter.getStringAttr("&"), globalLValue);
324  return success();
325  }
326  rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
327  operands.getNameAttr());
328  return success();
329  }
330 };
331 
332 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
334 
335  LogicalResult
336  matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
337  ConversionPatternRewriter &rewriter) const override {
338 
339  auto resultTy = getTypeConverter()->convertType(op.getType());
340  if (!resultTy) {
341  return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
342  }
343 
344  auto arrayValue =
345  dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
346  if (!arrayValue) {
347  return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
348  }
349 
350  auto subscript = emitc::SubscriptOp::create(
351  rewriter, op.getLoc(), arrayValue, operands.getIndices());
352 
353  rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
354  return success();
355  }
356 };
357 
358 struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
360 
361  LogicalResult
362  matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
363  ConversionPatternRewriter &rewriter) const override {
364  auto arrayValue =
365  dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
366  if (!arrayValue) {
367  return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
368  }
369 
370  auto subscript = emitc::SubscriptOp::create(
371  rewriter, op.getLoc(), arrayValue, operands.getIndices());
372  rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
373  operands.getValue());
374  return success();
375  }
376 };
377 } // namespace
378 
380  typeConverter.addConversion(
381  [&](MemRefType memRefType) -> std::optional<Type> {
382  if (!isMemRefTypeLegalForEmitC(memRefType)) {
383  return {};
384  }
385  Type convertedElementType =
386  typeConverter.convertType(memRefType.getElementType());
387  if (!convertedElementType)
388  return {};
389  return emitc::ArrayType::get(memRefType.getShape(),
390  convertedElementType);
391  });
392 
393  auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
394  ValueRange inputs,
395  Location loc) -> Value {
396  if (inputs.size() != 1)
397  return Value();
398 
399  return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
400  .getResult(0);
401  };
402 
403  typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
404  typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
405 }
406 
408  RewritePatternSet &patterns, const TypeConverter &converter) {
409  patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
410  ConvertGetGlobal, ConvertLoad, ConvertStore>(
411  converter, patterns.getContext());
412 }
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
Definition: MemRefToEmitC.h:12
constexpr const char * alignedAllocFunctionName
Definition: MemRefToEmitC.h:11
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition: SymbolTable.h:90
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)