25#include "llvm/ADT/STLExtras.h"
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
39struct MemRefToEmitCDialectInterface :
public ConvertToEmitCPatternInterface {
40 MemRefToEmitCDialectInterface(Dialect *dialect)
41 : ConvertToEmitCPatternInterface(dialect) {}
45 void populateConvertToEmitCConversionPatterns(
46 ConversionTarget &
target, TypeConverter &typeConverter,
47 RewritePatternSet &patterns)
const final {
56 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
65struct ConvertAlloca final :
public OpConversionPattern<memref::AllocaOp> {
66 using OpConversionPattern::OpConversionPattern;
69 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
70 ConversionPatternRewriter &rewriter)
const override {
72 if (!op.getType().hasStaticShape()) {
73 return rewriter.notifyMatchFailure(
74 op.getLoc(),
"cannot transform alloca with dynamic shape");
77 if (op.getAlignment().value_or(1) > 1) {
80 return rewriter.notifyMatchFailure(
81 op.getLoc(),
"cannot transform alloca with alignment requirement");
84 auto resultTy = getTypeConverter()->convertType(op.getType());
86 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
88 auto noInit = emitc::OpaqueAttr::get(
getContext(),
"");
89 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
96 if (opTy.getRank() == 0) {
99 resultTy = typeConverter->convertType(opTy);
104static Value calculateMemrefTotalSizeBytes(
Location loc, MemRefType memrefType,
107 "incompatible memref type for EmitC conversion");
108 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
109 builder, loc, emitc::SizeTType::get(builder.
getContext()),
112 {TypeAttr::get(memrefType.getElementType())}));
115 int64_t numElements = llvm::product_of(memrefType.getShape());
116 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
117 builder, loc, indexType, builder.
getIndexAttr(numElements));
120 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
121 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
123 return totalSizeBytes.getResult();
126static emitc::AddressOfOp
130 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
133 emitc::ArrayType arrayType = arrayValue.getType();
135 emitc::SubscriptOp subPtr =
137 emitc::AddressOfOp
ptr = emitc::AddressOfOp::create(
138 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
146static Value stripPointerUnrealizedCast(
Value v) {
147 if (
auto cast = v.
getDefiningOp<UnrealizedConversionCastOp>())
148 if (cast.getNumOperands() == 1 &&
149 isa<emitc::PointerType>(cast.getOperand(0).getType()))
150 return cast.getOperand(0);
155 MemRefType memrefType,
164 ? emitc::ConstantOp::create(builder, idxType, builder.
getIndexAttr(0))
167 for (
auto [dim, idx] : llvm::zip(
shape.drop_front(),
indices.drop_front())) {
169 emitc::ConstantOp::create(builder, idxType, builder.
getIndexAttr(dim));
170 linearIndex = emitc::MulOp::create(builder, idxType, linearIndex, dimSize);
171 linearIndex = emitc::AddOp::create(builder, idxType, linearIndex, idx);
176struct ConvertAlloc final :
public OpConversionPattern<memref::AllocOp> {
177 using OpConversionPattern::OpConversionPattern;
179 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
180 ConversionPatternRewriter &rewriter)
const override {
181 Location loc = allocOp.getLoc();
182 MemRefType memrefType = allocOp.getType();
184 return rewriter.notifyMatchFailure(
185 loc,
"incompatible memref type for EmitC conversion");
188 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
189 Type elementType = memrefType.getElementType();
190 IndexType indexType = rewriter.getIndexType();
191 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
192 rewriter, loc, sizeTType, rewriter.getStringAttr(
"sizeof"),
194 ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
196 int64_t numElements = 1;
197 for (int64_t dimSize : memrefType.getShape()) {
198 numElements *= dimSize;
200 Value numElementsValue = emitc::ConstantOp::create(
201 rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
203 Value totalSizeBytes =
204 emitc::MulOp::create(rewriter, loc, sizeTType,
205 sizeofElementOp.getResult(0), numElementsValue);
207 emitc::CallOpaqueOp allocCall;
208 StringAttr allocFunctionName;
209 Value alignmentValue;
210 SmallVector<Value, 2> argsVec;
211 if (allocOp.getAlignment()) {
213 alignmentValue = emitc::ConstantOp::create(
214 rewriter, loc, sizeTType,
215 rewriter.getIntegerAttr(indexType,
216 allocOp.getAlignment().value_or(0)));
217 argsVec.push_back(alignmentValue);
222 argsVec.push_back(totalSizeBytes);
225 allocCall = emitc::CallOpaqueOp::create(
227 emitc::PointerType::get(
228 emitc::OpaqueType::get(rewriter.getContext(),
"void")),
229 allocFunctionName, args);
231 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
232 emitc::CastOp castOp = emitc::CastOp::create(
233 rewriter, loc, targetPointerType, allocCall.getResult(0));
235 rewriter.replaceOp(allocOp, castOp);
240struct ConvertCopy final :
public OpConversionPattern<memref::CopyOp> {
241 using OpConversionPattern::OpConversionPattern;
244 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
245 ConversionPatternRewriter &rewriter)
const override {
246 Location loc = copyOp.getLoc();
247 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
248 MemRefType targetMemrefType =
249 cast<MemRefType>(copyOp.getTarget().getType());
252 return rewriter.notifyMatchFailure(
253 loc,
"incompatible source memref type for EmitC conversion");
256 return rewriter.notifyMatchFailure(
257 loc,
"incompatible target memref type for EmitC conversion");
260 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
261 emitc::AddressOfOp srcPtr =
262 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
264 auto targetArrayValue =
265 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
266 emitc::AddressOfOp targetPtr =
267 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
269 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
272 targetPtr.getResult(), srcPtr.getResult(),
273 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
275 rewriter.replaceOp(copyOp, memCpyCall.getResults());
281struct ConvertGlobal final :
public OpConversionPattern<memref::GlobalOp> {
282 using OpConversionPattern::OpConversionPattern;
285 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
286 ConversionPatternRewriter &rewriter)
const override {
287 MemRefType opTy = op.getType();
288 if (!op.getType().hasStaticShape()) {
289 return rewriter.notifyMatchFailure(
290 op.getLoc(),
"cannot transform global with dynamic shape");
293 if (op.getAlignment().value_or(1) > 1) {
295 return rewriter.notifyMatchFailure(
296 op.getLoc(),
"global variable with alignment requirement is "
297 "currently not supported");
300 Type resultTy = convertMemRefType(opTy, getTypeConverter());
303 return rewriter.notifyMatchFailure(op.getLoc(),
304 "cannot convert result type");
308 if (visibility != SymbolTable::Visibility::Public &&
309 visibility != SymbolTable::Visibility::Private) {
310 return rewriter.notifyMatchFailure(
312 "only public and private visibility is currently supported");
316 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
317 bool externSpecifier = !staticSpecifier;
319 Attribute initialValue = operands.getInitialValueAttr();
320 if (opTy.getRank() == 0) {
322 if (std::optional<Attribute> initValueAttr = op.getInitialValue()) {
323 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(*initValueAttr)) {
324 initialValue = elementsAttr.getSplatValue<Attribute>();
328 if (isa_and_present<UnitAttr>(initialValue))
331 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
332 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
333 staticSpecifier, operands.getConstant());
338struct ConvertGetGlobal final
339 :
public OpConversionPattern<memref::GetGlobalOp> {
340 using OpConversionPattern::OpConversionPattern;
343 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
344 ConversionPatternRewriter &rewriter)
const override {
346 MemRefType opTy = op.getType();
347 Type resultTy = convertMemRefType(opTy, getTypeConverter());
350 return rewriter.notifyMatchFailure(op.getLoc(),
351 "cannot convert result type");
354 if (opTy.getRank() == 0) {
355 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
356 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
357 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
358 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
359 rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
363 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
364 operands.getNameAttr());
369struct ConvertLoad final :
public OpConversionPattern<memref::LoadOp> {
370 using OpConversionPattern::OpConversionPattern;
374 ConversionPatternRewriter &rewriter)
const override {
375 Location loc = op.getLoc();
376 auto resultTy = getTypeConverter()->convertType(op.getType());
378 return rewriter.notifyMatchFailure(loc,
"cannot convert type");
382 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
383 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
384 if (!strippedPtr && arrayValue) {
385 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
386 operands.getIndices());
388 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
393 return rewriter.notifyMatchFailure(loc,
"expected array or pointer type");
394 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
397 ImplicitLocOpBuilder
b(loc, rewriter);
398 Value linearIndex = computeRowMajorLinearIndex(
b, opMemrefType,
indices);
399 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
401 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
403 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
408struct ConvertStore final :
public OpConversionPattern<memref::StoreOp> {
409 using OpConversionPattern::OpConversionPattern;
413 ConversionPatternRewriter &rewriter)
const override {
414 Location loc = op.getLoc();
416 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
417 Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
418 if (!strippedPtr && arrayValue) {
419 auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
420 operands.getIndices());
421 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
422 operands.getValue());
427 return rewriter.notifyMatchFailure(loc,
"expected array or pointer type");
428 MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
431 ImplicitLocOpBuilder
b(loc, rewriter);
432 Value linearIndex = computeRowMajorLinearIndex(
b, opMemrefType,
indices);
433 auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
435 emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
437 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
438 operands.getValue());
446 typeConverter.addConversion(
447 [&](MemRefType memRefType) -> std::optional<Type> {
451 Type convertedElementType =
452 typeConverter.convertType(memRefType.getElementType());
453 if (!convertedElementType)
455 return emitc::ArrayType::get(memRefType.getShape(),
456 convertedElementType);
459 auto materializeAsUnrealizedCast = [](
OpBuilder &builder,
Type resultType,
462 if (inputs.size() != 1)
465 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
469 typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
470 typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
475 patterns.
add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override