MLIR 22.0.0git
MemRefToEmitC.cpp
Go to the documentation of this file.
1//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert memref ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/TypeRange.h"
23#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include <cstdint>
27#include <numeric>
28
29using namespace mlir;
30
31static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
35}
36
37namespace {
38/// Implement the interface to convert MemRef to EmitC.
/// Attached to the MemRef dialect by registerConvertMemRefToEmitCInterface.
39struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
41
42 /// Hook for derived dialect interface to provide conversion patterns
43 /// and mark dialect legal for the conversion target.
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &target, TypeConverter &typeConverter,
46 RewritePatternSet &patterns) const final {
  // NOTE(review): presumably this forwards to
  // populateMemRefToEmitCTypeConversion(typeConverter) and
  // populateMemRefToEmitCConversionPatterns(patterns, typeConverter) -- confirm.
49 }
50};
51} // namespace
52
  // Register a dialect extension that adds MemRefToEmitCDialectInterface to
  // the MemRef dialect, so generic ConvertToEmitC drivers can discover the
  // memref patterns through the dialect interface.
54 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
56 });
57}
58
59//===----------------------------------------------------------------------===//
60// Conversion Patterns
61//===----------------------------------------------------------------------===//
62
63namespace {
64struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
65 using OpConversionPattern::OpConversionPattern;
66
67 LogicalResult
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69 ConversionPatternRewriter &rewriter) const override {
70
71 if (!op.getType().hasStaticShape()) {
72 return rewriter.notifyMatchFailure(
73 op.getLoc(), "cannot transform alloca with dynamic shape");
74 }
75
76 if (op.getAlignment().value_or(1) > 1) {
77 // TODO: Allow alignment if it is not more than the natural alignment
78 // of the C array.
79 return rewriter.notifyMatchFailure(
80 op.getLoc(), "cannot transform alloca with alignment requirement");
81 }
82
83 auto resultTy = getTypeConverter()->convertType(op.getType());
84 if (!resultTy) {
85 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
86 }
87 auto noInit = emitc::OpaqueAttr::get(getContext(), "");
88 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
89 return success();
90 }
91};
92
93Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
94 Type resultTy;
95 if (opTy.getRank() == 0) {
96 resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
97 } else {
98 resultTy = typeConverter->convertType(opTy);
99 }
100 return resultTy;
101}
102
103static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
104 OpBuilder &builder) {
105 assert(isMemRefTypeLegalForEmitC(memrefType) &&
106 "incompatible memref type for EmitC conversion");
107 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
108 builder, loc, emitc::SizeTType::get(builder.getContext()),
109 builder.getStringAttr("sizeof"), ValueRange{},
110 ArrayAttr::get(builder.getContext(),
111 {TypeAttr::get(memrefType.getElementType())}));
112
113 IndexType indexType = builder.getIndexType();
114 int64_t numElements = llvm::product_of(memrefType.getShape());
115 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
116 builder, loc, indexType, builder.getIndexAttr(numElements));
117
118 Type sizeTType = emitc::SizeTType::get(builder.getContext());
119 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
120 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
121
122 return totalSizeBytes.getResult();
123}
124
125static emitc::ApplyOp
126createPointerFromEmitcArray(Location loc, OpBuilder &builder,
127 TypedValue<emitc::ArrayType> arrayValue) {
128
129 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
130 builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
131
132 emitc::ArrayType arrayType = arrayValue.getType();
133 llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
134 emitc::SubscriptOp subPtr =
135 emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
136 emitc::ApplyOp ptr = emitc::ApplyOp::create(
137 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
138 builder.getStringAttr("&"), subPtr);
139
140 return ptr;
141}
142
143struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
144 using OpConversionPattern::OpConversionPattern;
145 LogicalResult
146 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
147 ConversionPatternRewriter &rewriter) const override {
148 Location loc = allocOp.getLoc();
149 MemRefType memrefType = allocOp.getType();
150 if (!isMemRefTypeLegalForEmitC(memrefType)) {
151 return rewriter.notifyMatchFailure(
152 loc, "incompatible memref type for EmitC conversion");
153 }
154
155 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
156 Type elementType = memrefType.getElementType();
157 IndexType indexType = rewriter.getIndexType();
158 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
159 rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
160 ValueRange{},
161 ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
162
163 int64_t numElements = 1;
164 for (int64_t dimSize : memrefType.getShape()) {
165 numElements *= dimSize;
166 }
167 Value numElementsValue = emitc::ConstantOp::create(
168 rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
169
170 Value totalSizeBytes =
171 emitc::MulOp::create(rewriter, loc, sizeTType,
172 sizeofElementOp.getResult(0), numElementsValue);
173
174 emitc::CallOpaqueOp allocCall;
175 StringAttr allocFunctionName;
176 Value alignmentValue;
177 SmallVector<Value, 2> argsVec;
178 if (allocOp.getAlignment()) {
179 allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
180 alignmentValue = emitc::ConstantOp::create(
181 rewriter, loc, sizeTType,
182 rewriter.getIntegerAttr(indexType,
183 allocOp.getAlignment().value_or(0)));
184 argsVec.push_back(alignmentValue);
185 } else {
186 allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
187 }
188
189 argsVec.push_back(totalSizeBytes);
190 ValueRange args(argsVec);
191
192 allocCall = emitc::CallOpaqueOp::create(
193 rewriter, loc,
194 emitc::PointerType::get(
195 emitc::OpaqueType::get(rewriter.getContext(), "void")),
196 allocFunctionName, args);
197
198 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
199 emitc::CastOp castOp = emitc::CastOp::create(
200 rewriter, loc, targetPointerType, allocCall.getResult(0));
201
202 rewriter.replaceOp(allocOp, castOp);
203 return success();
204 }
205};
206
207struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
208 using OpConversionPattern::OpConversionPattern;
209
210 LogicalResult
211 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
212 ConversionPatternRewriter &rewriter) const override {
213 Location loc = copyOp.getLoc();
214 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
215 MemRefType targetMemrefType =
216 cast<MemRefType>(copyOp.getTarget().getType());
217
218 if (!isMemRefTypeLegalForEmitC(srcMemrefType))
219 return rewriter.notifyMatchFailure(
220 loc, "incompatible source memref type for EmitC conversion");
221
222 if (!isMemRefTypeLegalForEmitC(targetMemrefType))
223 return rewriter.notifyMatchFailure(
224 loc, "incompatible target memref type for EmitC conversion");
225
226 auto srcArrayValue =
227 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
228 emitc::ApplyOp srcPtr =
229 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
230
231 auto targetArrayValue =
232 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
233 emitc::ApplyOp targetPtr =
234 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
235
236 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
237 rewriter, loc, TypeRange{}, "memcpy",
239 targetPtr.getResult(), srcPtr.getResult(),
240 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
241
242 rewriter.replaceOp(copyOp, memCpyCall.getResults());
243
244 return success();
245 }
246};
247
248struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
249 using OpConversionPattern::OpConversionPattern;
250
251 LogicalResult
252 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
253 ConversionPatternRewriter &rewriter) const override {
254 MemRefType opTy = op.getType();
255 if (!op.getType().hasStaticShape()) {
256 return rewriter.notifyMatchFailure(
257 op.getLoc(), "cannot transform global with dynamic shape");
258 }
259
260 if (op.getAlignment().value_or(1) > 1) {
261 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
262 return rewriter.notifyMatchFailure(
263 op.getLoc(), "global variable with alignment requirement is "
264 "currently not supported");
265 }
266
267 Type resultTy = convertMemRefType(opTy, getTypeConverter());
268
269 if (!resultTy) {
270 return rewriter.notifyMatchFailure(op.getLoc(),
271 "cannot convert result type");
272 }
273
275 if (visibility != SymbolTable::Visibility::Public &&
276 visibility != SymbolTable::Visibility::Private) {
277 return rewriter.notifyMatchFailure(
278 op.getLoc(),
279 "only public and private visibility is currently supported");
280 }
281 // We are explicit in specifing the linkage because the default linkage
282 // for constants is different in C and C++.
283 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
284 bool externSpecifier = !staticSpecifier;
285
286 Attribute initialValue = operands.getInitialValueAttr();
287 if (opTy.getRank() == 0) {
288 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
289 initialValue = elementsAttr.getSplatValue<Attribute>();
290 }
291 if (isa_and_present<UnitAttr>(initialValue))
292 initialValue = {};
293
294 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
295 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
296 staticSpecifier, operands.getConstant());
297 return success();
298 }
299};
300
301struct ConvertGetGlobal final
302 : public OpConversionPattern<memref::GetGlobalOp> {
303 using OpConversionPattern::OpConversionPattern;
304
305 LogicalResult
306 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
307 ConversionPatternRewriter &rewriter) const override {
308
309 MemRefType opTy = op.getType();
310 Type resultTy = convertMemRefType(opTy, getTypeConverter());
311
312 if (!resultTy) {
313 return rewriter.notifyMatchFailure(op.getLoc(),
314 "cannot convert result type");
315 }
316
317 if (opTy.getRank() == 0) {
318 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
319 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
320 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
321 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
322 rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
323 op, pointerType, rewriter.getStringAttr("&"), globalLValue);
324 return success();
325 }
326 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
327 operands.getNameAttr());
328 return success();
329 }
330};
331
332struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
333 using OpConversionPattern::OpConversionPattern;
334
335 LogicalResult
336 matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
337 ConversionPatternRewriter &rewriter) const override {
338
339 auto resultTy = getTypeConverter()->convertType(op.getType());
340 if (!resultTy) {
341 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
342 }
343
344 auto arrayValue =
345 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
346 if (!arrayValue) {
347 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
348 }
349
350 auto subscript = emitc::SubscriptOp::create(
351 rewriter, op.getLoc(), arrayValue, operands.getIndices());
352
353 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
354 return success();
355 }
356};
357
358struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
359 using OpConversionPattern::OpConversionPattern;
360
361 LogicalResult
362 matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
363 ConversionPatternRewriter &rewriter) const override {
364 auto arrayValue =
365 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
366 if (!arrayValue) {
367 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
368 }
369
370 auto subscript = emitc::SubscriptOp::create(
371 rewriter, op.getLoc(), arrayValue, operands.getIndices());
372 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
373 operands.getValue());
374 return success();
375 }
376};
377} // namespace
378
  // Map memref types accepted by isMemRefTypeLegalForEmitC to !emitc.array
  // with the same shape and a type-converted element type. The callback
  // captures `typeConverter` by reference and is stored in that same
  // converter, so it recurses through it to convert the element type.
380 typeConverter.addConversion(
381 [&](MemRefType memRefType) -> std::optional<Type> {
382 if (!isMemRefTypeLegalForEmitC(memRefType)) {
  // std::nullopt: this callback does not handle the type (other registered
  // conversions may still be tried).
383 return {};
384 }
385 Type convertedElementType =
386 typeConverter.convertType(memRefType.getElementType());
387 if (!convertedElementType)
388 return {};
389 return emitc::ArrayType::get(memRefType.getShape(),
390 convertedElementType);
391 });
392
  // Bridge between converted and unconverted values with a single
  // builtin.unrealized_conversion_cast. Only 1:1 materializations are
  // handled; multi-value inputs yield a null Value (materialization fails).
393 auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
394 ValueRange inputs,
395 Location loc) -> Value {
396 if (inputs.size() != 1)
397 return Value();
398
399 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
400 .getResult(0);
401 };
402
403 typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
404 typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
405}
406
408 RewritePatternSet &patterns, const TypeConverter &converter) {
  // All memref patterns share `converter`. The load/store/copy patterns expect
  // memref operands converted to !emitc.array, so `converter` presumably
  // needs populateMemRefToEmitCTypeConversion applied -- confirm at call sites.
409 patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
410 ConvertGetGlobal, ConvertLoad, ConvertStore>(
411 converter, patterns.getContext());
412}
return success()
b getContext())
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override