25#include "llvm/ADT/STLExtras.h"
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
39struct MemRefToEmitCDialectInterface :
public ConvertToEmitCPatternInterface {
40 MemRefToEmitCDialectInterface(Dialect *dialect)
41 : ConvertToEmitCPatternInterface(dialect) {}
45 void populateConvertToEmitCConversionPatterns(
46 ConversionTarget &
target, TypeConverter &typeConverter,
47 RewritePatternSet &patterns, std::optional<bool> lowerToCpp)
const final {
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
64struct ConvertAlloca final :
public OpConversionPattern<memref::AllocaOp> {
65 using OpConversionPattern::OpConversionPattern;
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69 ConversionPatternRewriter &rewriter)
const override {
71 if (!op.getType().hasStaticShape()) {
72 return rewriter.notifyMatchFailure(
73 op.getLoc(),
"cannot transform alloca with dynamic shape");
76 if (op.getAlignment().value_or(1) > 1) {
79 return rewriter.notifyMatchFailure(
80 op.getLoc(),
"cannot transform alloca with alignment requirement");
83 auto resultTy = getTypeConverter()->convertType(op.getType());
85 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
87 auto noInit = emitc::OpaqueAttr::get(
getContext(),
"");
88 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
95 if (opTy.getRank() == 0) {
98 resultTy = typeConverter->convertType(opTy);
103static Value calculateMemrefTotalSizeBytes(
Location loc, MemRefType memrefType,
105 Type convertedElementType) {
107 "incompatible memref type for EmitC conversion");
109 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
110 builder, loc, emitc::SizeTType::get(builder.
getContext()),
113 {TypeAttr::get(convertedElementType)}));
116 int64_t numElements = llvm::product_of(memrefType.getShape());
117 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
118 builder, loc, indexType, builder.
getIndexAttr(numElements));
121 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
122 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
124 return totalSizeBytes.getResult();
127static emitc::AddressOfOp
131 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
134 emitc::ArrayType arrayType = arrayValue.getType();
136 emitc::SubscriptOp subPtr =
138 emitc::AddressOfOp
ptr = emitc::AddressOfOp::create(
139 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
147static Value stripPointerUnrealizedCast(
Value v) {
148 if (
auto cast = v.
getDefiningOp<UnrealizedConversionCastOp>())
149 if (cast.getNumOperands() == 1 &&
150 isa<emitc::PointerType>(cast.getOperand(0).getType()))
151 return cast.getOperand(0);
156 MemRefType memrefType,
165 ? emitc::ConstantOp::create(builder, idxType, builder.
getIndexAttr(0))
168 for (
auto [dim, idx] : llvm::zip(
shape.drop_front(),
indices.drop_front())) {
170 emitc::ConstantOp::create(builder, idxType, builder.
getIndexAttr(dim));
171 linearIndex = emitc::MulOp::create(builder, idxType, linearIndex, dimSize);
172 linearIndex = emitc::AddOp::create(builder, idxType, linearIndex, idx);
177struct ConvertAlloc final :
public OpConversionPattern<memref::AllocOp> {
178 using OpConversionPattern::OpConversionPattern;
180 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
181 ConversionPatternRewriter &rewriter)
const override {
182 Location loc = allocOp.getLoc();
183 MemRefType memrefType = allocOp.getType();
185 return rewriter.notifyMatchFailure(
186 loc,
"incompatible memref type for EmitC conversion");
189 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
191 getTypeConverter()->convertType(memrefType.getElementType());
193 return rewriter.notifyMatchFailure(
194 loc,
"failed to convert memref element type");
196 IndexType indexType = rewriter.getIndexType();
197 Value totalSizeBytes =
198 calculateMemrefTotalSizeBytes(loc, memrefType, rewriter, elementType);
200 emitc::CallOpaqueOp allocCall;
201 StringAttr allocFunctionName;
202 Value alignmentValue;
203 SmallVector<Value, 2> argsVec;
204 if (allocOp.getAlignment()) {
206 alignmentValue = emitc::ConstantOp::create(
207 rewriter, loc, sizeTType,
208 rewriter.getIntegerAttr(indexType,
209 allocOp.getAlignment().value_or(0)));
210 argsVec.push_back(alignmentValue);
215 argsVec.push_back(totalSizeBytes);
218 allocCall = emitc::CallOpaqueOp::create(
220 emitc::PointerType::get(
221 emitc::OpaqueType::get(rewriter.getContext(),
"void")),
222 allocFunctionName, args);
224 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
225 emitc::CastOp castOp = emitc::CastOp::create(
226 rewriter, loc, targetPointerType, allocCall.getResult(0));
228 rewriter.replaceOp(allocOp, castOp);
233struct ConvertDealloc final :
public OpConversionPattern<memref::DeallocOp> {
234 using OpConversionPattern::OpConversionPattern;
237 matchAndRewrite(memref::DeallocOp deallocOp, OpAdaptor operands,
238 ConversionPatternRewriter &rewriter)
const override {
239 Location loc = deallocOp.getLoc();
243 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
245 return rewriter.notifyMatchFailure(
246 loc,
"expected pointer-backed memref for EmitC deallocation");
252 Type opaqueVoidPtrType = emitc::PointerType::get(
253 emitc::OpaqueType::get(rewriter.getContext(),
"void"));
255 emitc::CastOp::create(rewriter, loc, opaqueVoidPtrType, strippedPtr);
256 emitc::CallOpaqueOp freeCall = emitc::CallOpaqueOp::create(
259 rewriter.replaceOp(deallocOp, freeCall.getResults());
264struct ConvertCopy final :
public OpConversionPattern<memref::CopyOp> {
265 using OpConversionPattern::OpConversionPattern;
268 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
269 ConversionPatternRewriter &rewriter)
const override {
270 Location loc = copyOp.getLoc();
271 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
272 MemRefType targetMemrefType =
273 cast<MemRefType>(copyOp.getTarget().getType());
276 return rewriter.notifyMatchFailure(
277 loc,
"incompatible source memref type for EmitC conversion");
280 return rewriter.notifyMatchFailure(
281 loc,
"incompatible target memref type for EmitC conversion");
284 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
285 emitc::AddressOfOp srcPtr =
286 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
288 auto targetArrayValue =
289 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
290 emitc::AddressOfOp targetPtr =
291 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
293 Type convertedElementType =
294 getTypeConverter()->convertType(srcMemrefType.getElementType());
295 if (!convertedElementType) {
296 return rewriter.notifyMatchFailure(
297 loc,
"failed to convert memref element type");
299 Value totalSizeInBytes = calculateMemrefTotalSizeBytes(
300 loc, srcMemrefType, rewriter, convertedElementType);
301 emitc::CallOpaqueOp memCpyCall =
302 emitc::CallOpaqueOp::create(rewriter, loc,
TypeRange{},
"memcpy",
304 targetPtr.getResult(),
309 rewriter.replaceOp(copyOp, memCpyCall.getResults());
315struct ConvertGlobal final :
public OpConversionPattern<memref::GlobalOp> {
316 using OpConversionPattern::OpConversionPattern;
319 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
320 ConversionPatternRewriter &rewriter)
const override {
321 MemRefType opTy = op.getType();
322 if (!op.getType().hasStaticShape()) {
323 return rewriter.notifyMatchFailure(
324 op.getLoc(),
"cannot transform global with dynamic shape");
327 if (op.getAlignment().value_or(1) > 1) {
329 return rewriter.notifyMatchFailure(
330 op.getLoc(),
"global variable with alignment requirement is "
331 "currently not supported");
334 Type resultTy = convertMemRefType(opTy, getTypeConverter());
337 return rewriter.notifyMatchFailure(op.getLoc(),
338 "cannot convert result type");
342 if (visibility != SymbolTable::Visibility::Public &&
343 visibility != SymbolTable::Visibility::Private) {
344 return rewriter.notifyMatchFailure(
346 "only public and private visibility is currently supported");
350 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
351 bool externSpecifier = !staticSpecifier;
353 Attribute initialValue = operands.getInitialValueAttr();
354 if (opTy.getRank() == 0) {
356 if (std::optional<Attribute> initValueAttr = op.getInitialValue()) {
357 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(*initValueAttr)) {
358 initialValue = elementsAttr.getSplatValue<Attribute>();
362 if (isa_and_present<UnitAttr>(initialValue))
365 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
366 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
367 staticSpecifier, operands.getConstant());
372struct ConvertGetGlobal final
373 :
public OpConversionPattern<memref::GetGlobalOp> {
374 using OpConversionPattern::OpConversionPattern;
377 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
378 ConversionPatternRewriter &rewriter)
const override {
380 MemRefType opTy = op.getType();
381 Type resultTy = convertMemRefType(opTy, getTypeConverter());
384 return rewriter.notifyMatchFailure(op.getLoc(),
385 "cannot convert result type");
388 if (opTy.getRank() == 0) {
389 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
390 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
391 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
392 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
393 rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
397 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
398 operands.getNameAttr());
403struct ConvertLoad final :
public OpConversionPattern<memref::LoadOp> {
404 using OpConversionPattern::OpConversionPattern;
408 ConversionPatternRewriter &rewriter)
const override {
409 Location loc = op.getLoc();
410 auto resultTy = getTypeConverter()->convertType(op.getType());
412 return rewriter.notifyMatchFailure(loc,
"cannot convert type");
416 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
417 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
418 if (!strippedPtr && arrayValue) {
419 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
420 operands.getIndices());
422 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
427 return rewriter.notifyMatchFailure(loc,
"expected array or pointer type");
428 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
431 ImplicitLocOpBuilder
b(loc, rewriter);
432 Value linearIndex = computeRowMajorLinearIndex(
b, opMemrefType,
indices);
433 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
435 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
437 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
442struct ConvertStore final :
public OpConversionPattern<memref::StoreOp> {
443 using OpConversionPattern::OpConversionPattern;
447 ConversionPatternRewriter &rewriter)
const override {
448 Location loc = op.getLoc();
450 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
451 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
452 if (!strippedPtr && arrayValue) {
453 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
454 operands.getIndices());
455 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
456 operands.getValue());
461 return rewriter.notifyMatchFailure(loc,
"expected array or pointer type");
462 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
465 ImplicitLocOpBuilder
b(loc, rewriter);
466 Value linearIndex = computeRowMajorLinearIndex(
b, opMemrefType,
indices);
467 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
469 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
471 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
472 operands.getValue());
481 patterns.
add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertDealloc,
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * freeFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override