MLIR 23.0.0git
MemRefToEmitC.cpp
Go to the documentation of this file.
1//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert memref ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
14
18#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/TypeRange.h"
23#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include <cstdint>
27#include <numeric>
28
29using namespace mlir;
30
31static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
35}
36
37namespace {
38/// Implement the interface to convert MemRef to EmitC.
39struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
40 MemRefToEmitCDialectInterface(Dialect *dialect)
41 : ConvertToEmitCPatternInterface(dialect) {}
42
43 /// Hook for derived dialect interface to provide conversion patterns
44 /// and mark dialect legal for the conversion target.
45 void populateConvertToEmitCConversionPatterns(
46 ConversionTarget &target, TypeConverter &typeConverter,
47 RewritePatternSet &patterns, std::optional<bool> lowerToCpp) const final {
48 populateMemRefToEmitCConversionPatterns(patterns, typeConverter);
49 }
50};
51} // namespace
52
54 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
56 });
57}
58
59//===----------------------------------------------------------------------===//
60// Conversion Patterns
61//===----------------------------------------------------------------------===//
62
63namespace {
64struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
65 using OpConversionPattern::OpConversionPattern;
66
67 LogicalResult
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69 ConversionPatternRewriter &rewriter) const override {
70
71 if (!op.getType().hasStaticShape()) {
72 return rewriter.notifyMatchFailure(
73 op.getLoc(), "cannot transform alloca with dynamic shape");
74 }
75
76 if (op.getAlignment().value_or(1) > 1) {
77 // TODO: Allow alignment if it is not more than the natural alignment
78 // of the C array.
79 return rewriter.notifyMatchFailure(
80 op.getLoc(), "cannot transform alloca with alignment requirement");
81 }
82
83 auto resultTy = getTypeConverter()->convertType(op.getType());
84 if (!resultTy) {
85 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
86 }
87 auto noInit = emitc::OpaqueAttr::get(getContext(), "");
88 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
89 return success();
90 }
91};
92
93Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
94 Type resultTy;
95 if (opTy.getRank() == 0) {
96 resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
97 } else {
98 resultTy = typeConverter->convertType(opTy);
99 }
100 return resultTy;
101}
102
103static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
104 OpBuilder &builder,
105 Type convertedElementType) {
106 assert(isMemRefTypeLegalForEmitC(memrefType) &&
107 "incompatible memref type for EmitC conversion");
108
109 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
110 builder, loc, emitc::SizeTType::get(builder.getContext()),
111 builder.getStringAttr("sizeof"), ValueRange{},
112 ArrayAttr::get(builder.getContext(),
113 {TypeAttr::get(convertedElementType)}));
114
115 IndexType indexType = builder.getIndexType();
116 int64_t numElements = llvm::product_of(memrefType.getShape());
117 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
118 builder, loc, indexType, builder.getIndexAttr(numElements));
119
120 Type sizeTType = emitc::SizeTType::get(builder.getContext());
121 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
122 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
123
124 return totalSizeBytes.getResult();
125}
126
127static emitc::AddressOfOp
128createPointerFromEmitcArray(Location loc, OpBuilder &builder,
129 TypedValue<emitc::ArrayType> arrayValue) {
130
131 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
132 builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
133
134 emitc::ArrayType arrayType = arrayValue.getType();
135 llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
136 emitc::SubscriptOp subPtr =
137 emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
138 emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
139 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
140 subPtr);
141
142 return ptr;
143}
144
145// If `v` is defined through an unrealized cast and the source of that cast
146// is `emitc.ptr`, return the pointer.
147static Value stripPointerUnrealizedCast(Value v) {
148 if (auto cast = v.getDefiningOp<UnrealizedConversionCastOp>())
149 if (cast.getNumOperands() == 1 &&
150 isa<emitc::PointerType>(cast.getOperand(0).getType()))
151 return cast.getOperand(0);
152 return Value();
153}
154
155static Value computeRowMajorLinearIndex(ImplicitLocOpBuilder &builder,
156 MemRefType memrefType,
158 ArrayRef<int64_t> shape = memrefType.getShape();
159
160 Type idxType =
161 indices.empty() ? builder.getIndexType() : indices[0].getType();
162
163 Value linearIndex =
164 indices.empty()
165 ? emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(0))
166 : indices[0];
167
168 for (auto [dim, idx] : llvm::zip(shape.drop_front(), indices.drop_front())) {
169 Value dimSize =
170 emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(dim));
171 linearIndex = emitc::MulOp::create(builder, idxType, linearIndex, dimSize);
172 linearIndex = emitc::AddOp::create(builder, idxType, linearIndex, idx);
173 }
174 return linearIndex;
175}
176
177struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
178 using OpConversionPattern::OpConversionPattern;
179 LogicalResult
180 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
181 ConversionPatternRewriter &rewriter) const override {
182 Location loc = allocOp.getLoc();
183 MemRefType memrefType = allocOp.getType();
184 if (!isMemRefTypeLegalForEmitC(memrefType)) {
185 return rewriter.notifyMatchFailure(
186 loc, "incompatible memref type for EmitC conversion");
187 }
188
189 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
190 Type elementType =
191 getTypeConverter()->convertType(memrefType.getElementType());
192 if (!elementType) {
193 return rewriter.notifyMatchFailure(
194 loc, "failed to convert memref element type");
195 }
196 IndexType indexType = rewriter.getIndexType();
197 Value totalSizeBytes =
198 calculateMemrefTotalSizeBytes(loc, memrefType, rewriter, elementType);
199
200 emitc::CallOpaqueOp allocCall;
201 StringAttr allocFunctionName;
202 Value alignmentValue;
203 SmallVector<Value, 2> argsVec;
204 if (allocOp.getAlignment()) {
205 allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
206 alignmentValue = emitc::ConstantOp::create(
207 rewriter, loc, sizeTType,
208 rewriter.getIntegerAttr(indexType,
209 allocOp.getAlignment().value_or(0)));
210 argsVec.push_back(alignmentValue);
211 } else {
212 allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
213 }
214
215 argsVec.push_back(totalSizeBytes);
216 ValueRange args(argsVec);
217
218 allocCall = emitc::CallOpaqueOp::create(
219 rewriter, loc,
220 emitc::PointerType::get(
221 emitc::OpaqueType::get(rewriter.getContext(), "void")),
222 allocFunctionName, args);
223
224 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
225 emitc::CastOp castOp = emitc::CastOp::create(
226 rewriter, loc, targetPointerType, allocCall.getResult(0));
227
228 rewriter.replaceOp(allocOp, castOp);
229 return success();
230 }
231};
232
233struct ConvertDealloc final : public OpConversionPattern<memref::DeallocOp> {
234 using OpConversionPattern::OpConversionPattern;
235
236 LogicalResult
237 matchAndRewrite(memref::DeallocOp deallocOp, OpAdaptor operands,
238 ConversionPatternRewriter &rewriter) const override {
239 Location loc = deallocOp.getLoc();
240 // `free` can only be emitted when the dealloc operand is recoverable as an
241 // `emitc.ptr<T>`. In the current conversion, that happens via an
242 // unrealized_conversion_cast from the pointer-backed EmitC form.
243 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
244 if (!strippedPtr) {
245 return rewriter.notifyMatchFailure(
246 loc, "expected pointer-backed memref for EmitC deallocation");
247 }
248
249 // The allocation APIs used by MemRefToEmitC return `void *`, and `free`
250 // expects that same pointer type. Deallocation therefore only needs the
251 // recovered base pointer cast back to `void *` before calling `free`.
252 Type opaqueVoidPtrType = emitc::PointerType::get(
253 emitc::OpaqueType::get(rewriter.getContext(), "void"));
254 Value freeArg =
255 emitc::CastOp::create(rewriter, loc, opaqueVoidPtrType, strippedPtr);
256 emitc::CallOpaqueOp freeCall = emitc::CallOpaqueOp::create(
257 rewriter, loc, TypeRange{}, rewriter.getStringAttr(freeFunctionName),
258 ValueRange{freeArg});
259 rewriter.replaceOp(deallocOp, freeCall.getResults());
260 return success();
261 }
262};
263
264struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
265 using OpConversionPattern::OpConversionPattern;
266
267 LogicalResult
268 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
269 ConversionPatternRewriter &rewriter) const override {
270 Location loc = copyOp.getLoc();
271 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
272 MemRefType targetMemrefType =
273 cast<MemRefType>(copyOp.getTarget().getType());
274
275 if (!isMemRefTypeLegalForEmitC(srcMemrefType))
276 return rewriter.notifyMatchFailure(
277 loc, "incompatible source memref type for EmitC conversion");
278
279 if (!isMemRefTypeLegalForEmitC(targetMemrefType))
280 return rewriter.notifyMatchFailure(
281 loc, "incompatible target memref type for EmitC conversion");
282
283 auto srcArrayValue =
284 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
285 emitc::AddressOfOp srcPtr =
286 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
287
288 auto targetArrayValue =
289 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
290 emitc::AddressOfOp targetPtr =
291 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
292
293 Type convertedElementType =
294 getTypeConverter()->convertType(srcMemrefType.getElementType());
295 if (!convertedElementType) {
296 return rewriter.notifyMatchFailure(
297 loc, "failed to convert memref element type");
298 }
299 Value totalSizeInBytes = calculateMemrefTotalSizeBytes(
300 loc, srcMemrefType, rewriter, convertedElementType);
301 emitc::CallOpaqueOp memCpyCall =
302 emitc::CallOpaqueOp::create(rewriter, loc, TypeRange{}, "memcpy",
304 targetPtr.getResult(),
305 srcPtr.getResult(),
306 totalSizeInBytes,
307 });
308
309 rewriter.replaceOp(copyOp, memCpyCall.getResults());
310
311 return success();
312 }
313};
314
315struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
316 using OpConversionPattern::OpConversionPattern;
317
318 LogicalResult
319 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
320 ConversionPatternRewriter &rewriter) const override {
321 MemRefType opTy = op.getType();
322 if (!op.getType().hasStaticShape()) {
323 return rewriter.notifyMatchFailure(
324 op.getLoc(), "cannot transform global with dynamic shape");
325 }
326
327 if (op.getAlignment().value_or(1) > 1) {
328 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
329 return rewriter.notifyMatchFailure(
330 op.getLoc(), "global variable with alignment requirement is "
331 "currently not supported");
332 }
333
334 Type resultTy = convertMemRefType(opTy, getTypeConverter());
335
336 if (!resultTy) {
337 return rewriter.notifyMatchFailure(op.getLoc(),
338 "cannot convert result type");
339 }
340
342 if (visibility != SymbolTable::Visibility::Public &&
343 visibility != SymbolTable::Visibility::Private) {
344 return rewriter.notifyMatchFailure(
345 op.getLoc(),
346 "only public and private visibility is currently supported");
347 }
348 // We are explicit in specifying the linkage because the default linkage
349 // for constants is different in C and C++.
350 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
351 bool externSpecifier = !staticSpecifier;
352
353 Attribute initialValue = operands.getInitialValueAttr();
354 if (opTy.getRank() == 0) {
355 // special case for `variable : memref<i32> = dense<-1>`
356 if (std::optional<Attribute> initValueAttr = op.getInitialValue()) {
357 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(*initValueAttr)) {
358 initialValue = elementsAttr.getSplatValue<Attribute>();
359 }
360 }
361 }
362 if (isa_and_present<UnitAttr>(initialValue))
363 initialValue = {};
364
365 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
366 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
367 staticSpecifier, operands.getConstant());
368 return success();
369 }
370};
371
372struct ConvertGetGlobal final
373 : public OpConversionPattern<memref::GetGlobalOp> {
374 using OpConversionPattern::OpConversionPattern;
375
376 LogicalResult
377 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
378 ConversionPatternRewriter &rewriter) const override {
379
380 MemRefType opTy = op.getType();
381 Type resultTy = convertMemRefType(opTy, getTypeConverter());
382
383 if (!resultTy) {
384 return rewriter.notifyMatchFailure(op.getLoc(),
385 "cannot convert result type");
386 }
387
388 if (opTy.getRank() == 0) {
389 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
390 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
391 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
392 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
393 rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
394 globalLValue);
395 return success();
396 }
397 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
398 operands.getNameAttr());
399 return success();
400 }
401};
402
403struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
404 using OpConversionPattern::OpConversionPattern;
405
406 LogicalResult
407 matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
408 ConversionPatternRewriter &rewriter) const override {
409 Location loc = op.getLoc();
410 auto resultTy = getTypeConverter()->convertType(op.getType());
411 if (!resultTy) {
412 return rewriter.notifyMatchFailure(loc, "cannot convert type");
413 }
414
415 auto arrayValue =
416 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
417 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
418 if (!strippedPtr && arrayValue) {
419 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
420 operands.getIndices());
421
422 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
423 return success();
424 }
425
426 if (!strippedPtr)
427 return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
428 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
429 ValueRange indices = operands.getIndices();
430
431 ImplicitLocOpBuilder b(loc, rewriter);
432 Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
433 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
434 auto subscript =
435 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
436
437 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
438 return success();
439 }
440};
441
442struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
443 using OpConversionPattern::OpConversionPattern;
444
445 LogicalResult
446 matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
447 ConversionPatternRewriter &rewriter) const override {
448 Location loc = op.getLoc();
449 auto arrayValue =
450 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
451 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
452 if (!strippedPtr && arrayValue) {
453 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
454 operands.getIndices());
455 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
456 operands.getValue());
457 return success();
458 }
459
460 if (!strippedPtr)
461 return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
462 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
463 ValueRange indices = operands.getIndices();
464
465 ImplicitLocOpBuilder b(loc, rewriter);
466 Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
467 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
468 auto subscript =
469 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
470
471 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
472 operands.getValue());
473 return success();
474 }
475};
476
477} // namespace
478
480 RewritePatternSet &patterns, const TypeConverter &converter) {
481 patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertDealloc,
482 ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
483 converter, patterns.getContext());
484}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * freeFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Definition SymbolTable.h:90
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override