MLIR 23.0.0git
MemRefToEmitC.cpp
Go to the documentation of this file.
1//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert memref ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/TypeRange.h"
23#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include <cstdint>
27#include <numeric>
28
29using namespace mlir;
30
31static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
35}
36
37namespace {
38/// Implement the interface to convert MemRef to EmitC.
/// Instances are attached to the memref dialect by
/// registerConvertMemRefToEmitCInterface in this file.
39struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
40 MemRefToEmitCDialectInterface(Dialect *dialect)
41 : ConvertToEmitCPatternInterface(dialect) {}
42
43 /// Hook for derived dialect interface to provide conversion patterns
44 /// and mark dialect legal for the conversion target.
/// Adds the MemRef -> EmitC rewrite patterns to `patterns`, constructed with
/// `typeConverter`.
/// NOTE(review): `target` is not referenced by the visible statements;
/// confirm no legality marking is expected from this hook.
45 void populateConvertToEmitCConversionPatterns(
46 ConversionTarget &target, TypeConverter &typeConverter,
47 RewritePatternSet &patterns) const final {
49 populateMemRefToEmitCConversionPatterns(patterns, typeConverter);
50 }
51};
52} // namespace
53
// Attach MemRefToEmitCDialectInterface to the memref dialect through a
// registry extension; the callback receives the memref dialect instance.
55 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
56 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
57 });
58}
59
60//===----------------------------------------------------------------------===//
61// Conversion Patterns
62//===----------------------------------------------------------------------===//
63
64namespace {
65struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
66 using OpConversionPattern::OpConversionPattern;
67
68 LogicalResult
69 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
70 ConversionPatternRewriter &rewriter) const override {
71
72 if (!op.getType().hasStaticShape()) {
73 return rewriter.notifyMatchFailure(
74 op.getLoc(), "cannot transform alloca with dynamic shape");
75 }
76
77 if (op.getAlignment().value_or(1) > 1) {
78 // TODO: Allow alignment if it is not more than the natural alignment
79 // of the C array.
80 return rewriter.notifyMatchFailure(
81 op.getLoc(), "cannot transform alloca with alignment requirement");
82 }
83
84 auto resultTy = getTypeConverter()->convertType(op.getType());
85 if (!resultTy) {
86 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
87 }
88 auto noInit = emitc::OpaqueAttr::get(getContext(), "");
89 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
90 return success();
91 }
92};
93
94Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
95 Type resultTy;
96 if (opTy.getRank() == 0) {
97 resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
98 } else {
99 resultTy = typeConverter->convertType(opTy);
100 }
101 return resultTy;
102}
103
104static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
105 OpBuilder &builder) {
106 assert(isMemRefTypeLegalForEmitC(memrefType) &&
107 "incompatible memref type for EmitC conversion");
108 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
109 builder, loc, emitc::SizeTType::get(builder.getContext()),
110 builder.getStringAttr("sizeof"), ValueRange{},
111 ArrayAttr::get(builder.getContext(),
112 {TypeAttr::get(memrefType.getElementType())}));
113
114 IndexType indexType = builder.getIndexType();
115 int64_t numElements = llvm::product_of(memrefType.getShape());
116 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
117 builder, loc, indexType, builder.getIndexAttr(numElements));
118
119 Type sizeTType = emitc::SizeTType::get(builder.getContext());
120 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
121 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
122
123 return totalSizeBytes.getResult();
124}
125
126static emitc::AddressOfOp
127createPointerFromEmitcArray(Location loc, OpBuilder &builder,
128 TypedValue<emitc::ArrayType> arrayValue) {
129
130 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
131 builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
132
133 emitc::ArrayType arrayType = arrayValue.getType();
134 llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
135 emitc::SubscriptOp subPtr =
136 emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
137 emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
138 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
139 subPtr);
140
141 return ptr;
142}
143
144// If `v` is defined through an unrealized cast and the source of that cast
145// is `emitc.ptr`, return the pointer.
146static Value stripPointerUnrealizedCast(Value v) {
147 if (auto cast = v.getDefiningOp<UnrealizedConversionCastOp>())
148 if (cast.getNumOperands() == 1 &&
149 isa<emitc::PointerType>(cast.getOperand(0).getType()))
150 return cast.getOperand(0);
151 return Value();
152}
153
/// Emits the row-major linear offset of `indices` into a buffer shaped like
/// `memrefType`, i.e. ((i0 * d1 + i1) * d2 + i2) * ... where dK is the static
/// extent of dimension K. The extent of dimension 0 is never used.
/// Arithmetic is performed in the type of the first index. With no indices,
/// an `index`-typed constant 0 is returned.
154static Value computeRowMajorLinearIndex(ImplicitLocOpBuilder &builder,
155 MemRefType memrefType,
157 ArrayRef<int64_t> shape = memrefType.getShape();
158
159 Type idxType =
160 indices.empty() ? builder.getIndexType() : indices[0].getType();
161
162 Value linearIndex =
163 indices.empty()
164 ? emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(0))
165 : indices[0];
166
// Horner-style accumulation: scale the running offset by the next extent,
// then add the next index.
// NOTE(review): the extent constants are built from an `index` attribute
// even when `idxType` is not `index`; confirm emitc.constant accepts this.
167 for (auto [dim, idx] : llvm::zip(shape.drop_front(), indices.drop_front())) {
168 Value dimSize =
169 emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(dim));
170 linearIndex = emitc::MulOp::create(builder, idxType, linearIndex, dimSize);
171 linearIndex = emitc::AddOp::create(builder, idxType, linearIndex, idx);
172 }
173 return linearIndex;
174}
175
176struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
177 using OpConversionPattern::OpConversionPattern;
178 LogicalResult
179 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
180 ConversionPatternRewriter &rewriter) const override {
181 Location loc = allocOp.getLoc();
182 MemRefType memrefType = allocOp.getType();
183 if (!isMemRefTypeLegalForEmitC(memrefType)) {
184 return rewriter.notifyMatchFailure(
185 loc, "incompatible memref type for EmitC conversion");
186 }
187
188 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
189 Type elementType = memrefType.getElementType();
190 IndexType indexType = rewriter.getIndexType();
191 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
192 rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
193 ValueRange{},
194 ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
195
196 int64_t numElements = 1;
197 for (int64_t dimSize : memrefType.getShape()) {
198 numElements *= dimSize;
199 }
200 Value numElementsValue = emitc::ConstantOp::create(
201 rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
202
203 Value totalSizeBytes =
204 emitc::MulOp::create(rewriter, loc, sizeTType,
205 sizeofElementOp.getResult(0), numElementsValue);
206
207 emitc::CallOpaqueOp allocCall;
208 StringAttr allocFunctionName;
209 Value alignmentValue;
210 SmallVector<Value, 2> argsVec;
211 if (allocOp.getAlignment()) {
212 allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
213 alignmentValue = emitc::ConstantOp::create(
214 rewriter, loc, sizeTType,
215 rewriter.getIntegerAttr(indexType,
216 allocOp.getAlignment().value_or(0)));
217 argsVec.push_back(alignmentValue);
218 } else {
219 allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
220 }
221
222 argsVec.push_back(totalSizeBytes);
223 ValueRange args(argsVec);
224
225 allocCall = emitc::CallOpaqueOp::create(
226 rewriter, loc,
227 emitc::PointerType::get(
228 emitc::OpaqueType::get(rewriter.getContext(), "void")),
229 allocFunctionName, args);
230
231 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
232 emitc::CastOp castOp = emitc::CastOp::create(
233 rewriter, loc, targetPointerType, allocCall.getResult(0));
234
235 rewriter.replaceOp(allocOp, castOp);
236 return success();
237 }
238};
239
/// Lowers `memref.copy` to `memcpy(&target[0]..[0], &source[0]..[0], bytes)`.
/// The byte count is sizeof(element) * element count of the source type.
240struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
241 using OpConversionPattern::OpConversionPattern;
242
243 LogicalResult
244 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
245 ConversionPatternRewriter &rewriter) const override {
246 Location loc = copyOp.getLoc();
247 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
248 MemRefType targetMemrefType =
249 cast<MemRefType>(copyOp.getTarget().getType());
250
251 if (!isMemRefTypeLegalForEmitC(srcMemrefType))
252 return rewriter.notifyMatchFailure(
253 loc, "incompatible source memref type for EmitC conversion");
254
255 if (!isMemRefTypeLegalForEmitC(targetMemrefType))
256 return rewriter.notifyMatchFailure(
257 loc, "incompatible target memref type for EmitC conversion");
258
// Both converted operands are assumed to be emitc.array values.
// NOTE(review): unlike ConvertLoad/ConvertStore, this pattern does not look
// through an unrealized cast of an emitc.ptr (stripPointerUnrealizedCast),
// so pointer-backed memrefs (e.g. from memref.alloc) are treated as arrays.
// Confirm this is intended.
259 auto srcArrayValue =
260 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
261 emitc::AddressOfOp srcPtr =
262 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
263
264 auto targetArrayValue =
265 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
266 emitc::AddressOfOp targetPtr =
267 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
268
// memcpy(dest, src, n): the target pointer comes first. The size is taken
// from the source type only (presumably source and target have identical
// shape and element type; confirm against the memref.copy verifier).
269 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
270 rewriter, loc, TypeRange{}, "memcpy",
272 targetPtr.getResult(), srcPtr.getResult(),
273 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
274
275 rewriter.replaceOp(copyOp, memCpyCall.getResults());
276
277 return success();
278 }
279};
280
/// Lowers `memref.global` to `emitc.global`. Rank-0 globals become scalars of
/// the converted element type; other globals use the converted memref type.
/// Private symbols get a `static` specifier and public symbols get `extern`.
281struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
282 using OpConversionPattern::OpConversionPattern;
283
284 LogicalResult
285 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
286 ConversionPatternRewriter &rewriter) const override {
287 MemRefType opTy = op.getType();
288 if (!op.getType().hasStaticShape()) {
289 return rewriter.notifyMatchFailure(
290 op.getLoc(), "cannot transform global with dynamic shape");
291 }
292
293 if (op.getAlignment().value_or(1) > 1) {
294 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
295 return rewriter.notifyMatchFailure(
296 op.getLoc(), "global variable with alignment requirement is "
297 "currently not supported");
298 }
299
300 Type resultTy = convertMemRefType(opTy, getTypeConverter());
301
302 if (!resultTy) {
303 return rewriter.notifyMatchFailure(op.getLoc(),
304 "cannot convert result type");
305 }
306
// Only public and private symbols map onto C linkage (extern / static).
308 if (visibility != SymbolTable::Visibility::Public &&
309 visibility != SymbolTable::Visibility::Private) {
310 return rewriter.notifyMatchFailure(
311 op.getLoc(),
312 "only public and private visibility is currently supported");
313 }
314 // We are explicit in specifying the linkage because the default linkage
315 // for constants is different in C and C++.
316 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
317 bool externSpecifier = !staticSpecifier;
318
319 Attribute initialValue = operands.getInitialValueAttr();
320 if (opTy.getRank() == 0) {
321 // special case for `variable : memref<i32> = dense<-1>`
// A rank-0 elements attribute holds a single element, so its splat value
// is that scalar; it initializes the scalar global directly.
322 if (std::optional<Attribute> initValueAttr = op.getInitialValue()) {
323 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(*initValueAttr)) {
324 initialValue = elementsAttr.getSplatValue<Attribute>();
325 }
326 }
327 }
// A UnitAttr initial value carries no data (presumably the `uninitialized`
// form), so the global is emitted without an initializer.
328 if (isa_and_present<UnitAttr>(initialValue))
329 initialValue = {};
330
331 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
332 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
333 staticSpecifier, operands.getConstant());
334 return success();
335 }
336};
337
338struct ConvertGetGlobal final
339 : public OpConversionPattern<memref::GetGlobalOp> {
340 using OpConversionPattern::OpConversionPattern;
341
342 LogicalResult
343 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
344 ConversionPatternRewriter &rewriter) const override {
345
346 MemRefType opTy = op.getType();
347 Type resultTy = convertMemRefType(opTy, getTypeConverter());
348
349 if (!resultTy) {
350 return rewriter.notifyMatchFailure(op.getLoc(),
351 "cannot convert result type");
352 }
353
354 if (opTy.getRank() == 0) {
355 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
356 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
357 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
358 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
359 rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
360 globalLValue);
361 return success();
362 }
363 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
364 operands.getNameAttr());
365 return success();
366 }
367};
368
369struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
370 using OpConversionPattern::OpConversionPattern;
371
372 LogicalResult
373 matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
374 ConversionPatternRewriter &rewriter) const override {
375 Location loc = op.getLoc();
376 auto resultTy = getTypeConverter()->convertType(op.getType());
377 if (!resultTy) {
378 return rewriter.notifyMatchFailure(loc, "cannot convert type");
379 }
380
381 auto arrayValue =
382 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
383 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
384 if (!strippedPtr && arrayValue) {
385 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
386 operands.getIndices());
387
388 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
389 return success();
390 }
391
392 if (!strippedPtr)
393 return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
394 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
395 ValueRange indices = operands.getIndices();
396
397 ImplicitLocOpBuilder b(loc, rewriter);
398 Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
399 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
400 auto subscript =
401 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
402
403 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
404 return success();
405 }
406};
407
408struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
409 using OpConversionPattern::OpConversionPattern;
410
411 LogicalResult
412 matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
413 ConversionPatternRewriter &rewriter) const override {
414 Location loc = op.getLoc();
415 auto arrayValue =
416 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
417 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
418 if (!strippedPtr && arrayValue) {
419 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
420 operands.getIndices());
421 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
422 operands.getValue());
423 return success();
424 }
425
426 if (!strippedPtr)
427 return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
428 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
429 ValueRange indices = operands.getIndices();
430
431 ImplicitLocOpBuilder b(loc, rewriter);
432 Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
433 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
434 auto subscript =
435 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
436
437 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
438 operands.getValue());
439 return success();
440 }
441};
442
443} // namespace
444
// Map legal memrefs (see isMemRefTypeLegalForEmitC) to emitc.array with the
// same shape and a recursively converted element type. Returning `{}`
// (std::nullopt) means this conversion does not apply, so other registered
// conversions may still be tried for the type.
446 typeConverter.addConversion(
447 [&](MemRefType memRefType) -> std::optional<Type> {
448 if (!isMemRefTypeLegalForEmitC(memRefType)) {
449 return {};
450 }
451 Type convertedElementType =
452 typeConverter.convertType(memRefType.getElementType());
453 if (!convertedElementType)
454 return {};
455 return emitc::ArrayType::get(memRefType.getShape(),
456 convertedElementType);
457 });
458
// Bridge type mismatches during conversion with a single-input
// unrealized_conversion_cast. Multi-input materializations are declined by
// returning a null Value. ConvertLoad/ConvertStore look through these casts
// via stripPointerUnrealizedCast.
459 auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
460 ValueRange inputs,
461 Location loc) -> Value {
462 if (inputs.size() != 1)
463 return Value();
464
465 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
466 .getResult(0);
467 };
468
469 typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
470 typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
471}
472
474 RewritePatternSet &patterns, const TypeConverter &converter) {
// Register every MemRef -> EmitC pattern defined in this file. Each one is
// constructed with `converter` and the context of `patterns`.
475 patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
476 ConvertGetGlobal, ConvertLoad, ConvertStore>(
477 converter, patterns.getContext());
478}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Include the generated interface declarations.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override