25#include "llvm/ADT/STLExtras.h"
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &
target, TypeConverter &typeConverter,
46 RewritePatternSet &
patterns)
const final {
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
64struct ConvertAlloca final :
public OpConversionPattern<memref::AllocaOp> {
65 using OpConversionPattern::OpConversionPattern;
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69 ConversionPatternRewriter &rewriter)
const override {
71 if (!op.getType().hasStaticShape()) {
72 return rewriter.notifyMatchFailure(
73 op.getLoc(),
"cannot transform alloca with dynamic shape");
76 if (op.getAlignment().value_or(1) > 1) {
79 return rewriter.notifyMatchFailure(
80 op.getLoc(),
"cannot transform alloca with alignment requirement");
83 auto resultTy = getTypeConverter()->convertType(op.getType());
85 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
87 auto noInit = emitc::OpaqueAttr::get(
getContext(),
"");
88 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
95 if (opTy.getRank() == 0) {
98 resultTy = typeConverter->convertType(opTy);
103static Value calculateMemrefTotalSizeBytes(
Location loc, MemRefType memrefType,
106 "incompatible memref type for EmitC conversion");
107 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
108 builder, loc, emitc::SizeTType::get(builder.
getContext()),
111 {TypeAttr::get(memrefType.getElementType())}));
114 int64_t numElements = llvm::product_of(memrefType.getShape());
115 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
116 builder, loc, indexType, builder.
getIndexAttr(numElements));
119 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
120 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
122 return totalSizeBytes.getResult();
129 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
132 emitc::ArrayType arrayType = arrayValue.getType();
134 emitc::SubscriptOp subPtr =
136 emitc::ApplyOp
ptr = emitc::ApplyOp::create(
137 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
143struct ConvertAlloc final :
public OpConversionPattern<memref::AllocOp> {
144 using OpConversionPattern::OpConversionPattern;
146 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
147 ConversionPatternRewriter &rewriter)
const override {
148 Location loc = allocOp.getLoc();
149 MemRefType memrefType = allocOp.getType();
151 return rewriter.notifyMatchFailure(
152 loc,
"incompatible memref type for EmitC conversion");
155 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
156 Type elementType = memrefType.getElementType();
157 IndexType indexType = rewriter.getIndexType();
158 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
159 rewriter, loc, sizeTType, rewriter.getStringAttr(
"sizeof"),
161 ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
163 int64_t numElements = 1;
164 for (int64_t dimSize : memrefType.getShape()) {
165 numElements *= dimSize;
167 Value numElementsValue = emitc::ConstantOp::create(
168 rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
170 Value totalSizeBytes =
171 emitc::MulOp::create(rewriter, loc, sizeTType,
172 sizeofElementOp.getResult(0), numElementsValue);
174 emitc::CallOpaqueOp allocCall;
175 StringAttr allocFunctionName;
176 Value alignmentValue;
177 SmallVector<Value, 2> argsVec;
178 if (allocOp.getAlignment()) {
180 alignmentValue = emitc::ConstantOp::create(
181 rewriter, loc, sizeTType,
182 rewriter.getIntegerAttr(indexType,
183 allocOp.getAlignment().value_or(0)));
184 argsVec.push_back(alignmentValue);
189 argsVec.push_back(totalSizeBytes);
192 allocCall = emitc::CallOpaqueOp::create(
194 emitc::PointerType::get(
195 emitc::OpaqueType::get(rewriter.getContext(),
"void")),
196 allocFunctionName, args);
198 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
199 emitc::CastOp castOp = emitc::CastOp::create(
200 rewriter, loc, targetPointerType, allocCall.getResult(0));
202 rewriter.replaceOp(allocOp, castOp);
207struct ConvertCopy final :
public OpConversionPattern<memref::CopyOp> {
208 using OpConversionPattern::OpConversionPattern;
211 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
212 ConversionPatternRewriter &rewriter)
const override {
213 Location loc = copyOp.getLoc();
214 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
215 MemRefType targetMemrefType =
216 cast<MemRefType>(copyOp.getTarget().getType());
219 return rewriter.notifyMatchFailure(
220 loc,
"incompatible source memref type for EmitC conversion");
223 return rewriter.notifyMatchFailure(
224 loc,
"incompatible target memref type for EmitC conversion");
227 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
228 emitc::ApplyOp srcPtr =
229 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
231 auto targetArrayValue =
232 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
233 emitc::ApplyOp targetPtr =
234 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
236 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
239 targetPtr.getResult(), srcPtr.getResult(),
240 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
242 rewriter.replaceOp(copyOp, memCpyCall.getResults());
248struct ConvertGlobal final :
public OpConversionPattern<memref::GlobalOp> {
249 using OpConversionPattern::OpConversionPattern;
252 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
253 ConversionPatternRewriter &rewriter)
const override {
254 MemRefType opTy = op.getType();
255 if (!op.getType().hasStaticShape()) {
256 return rewriter.notifyMatchFailure(
257 op.getLoc(),
"cannot transform global with dynamic shape");
260 if (op.getAlignment().value_or(1) > 1) {
262 return rewriter.notifyMatchFailure(
263 op.getLoc(),
"global variable with alignment requirement is "
264 "currently not supported");
267 Type resultTy = convertMemRefType(opTy, getTypeConverter());
270 return rewriter.notifyMatchFailure(op.getLoc(),
271 "cannot convert result type");
275 if (visibility != SymbolTable::Visibility::Public &&
276 visibility != SymbolTable::Visibility::Private) {
277 return rewriter.notifyMatchFailure(
279 "only public and private visibility is currently supported");
283 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
284 bool externSpecifier = !staticSpecifier;
286 Attribute initialValue = operands.getInitialValueAttr();
287 if (opTy.getRank() == 0) {
288 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
289 initialValue = elementsAttr.getSplatValue<Attribute>();
291 if (isa_and_present<UnitAttr>(initialValue))
294 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
295 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
296 staticSpecifier, operands.getConstant());
301struct ConvertGetGlobal final
302 :
public OpConversionPattern<memref::GetGlobalOp> {
303 using OpConversionPattern::OpConversionPattern;
306 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
307 ConversionPatternRewriter &rewriter)
const override {
309 MemRefType opTy = op.getType();
310 Type resultTy = convertMemRefType(opTy, getTypeConverter());
313 return rewriter.notifyMatchFailure(op.getLoc(),
314 "cannot convert result type");
317 if (opTy.getRank() == 0) {
318 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
319 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
320 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
321 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
322 rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
323 op, pointerType, rewriter.getStringAttr(
"&"), globalLValue);
326 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
327 operands.getNameAttr());
332struct ConvertLoad final :
public OpConversionPattern<memref::LoadOp> {
333 using OpConversionPattern::OpConversionPattern;
337 ConversionPatternRewriter &rewriter)
const override {
339 auto resultTy = getTypeConverter()->convertType(op.getType());
341 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
345 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
347 return rewriter.notifyMatchFailure(op.getLoc(),
"expected array type");
350 auto subscript = emitc::SubscriptOp::create(
351 rewriter, op.getLoc(), arrayValue, operands.getIndices());
353 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
358struct ConvertStore final :
public OpConversionPattern<memref::StoreOp> {
359 using OpConversionPattern::OpConversionPattern;
363 ConversionPatternRewriter &rewriter)
const override {
365 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
367 return rewriter.notifyMatchFailure(op.getLoc(),
"expected array type");
370 auto subscript = emitc::SubscriptOp::create(
371 rewriter, op.getLoc(), arrayValue, operands.getIndices());
372 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
373 operands.getValue());
380 typeConverter.addConversion(
381 [&](MemRefType memRefType) -> std::optional<Type> {
385 Type convertedElementType =
386 typeConverter.convertType(memRefType.getElementType());
387 if (!convertedElementType)
389 return emitc::ArrayType::get(memRefType.getShape(),
390 convertedElementType);
393 auto materializeAsUnrealizedCast = [](
OpBuilder &builder,
Type resultType,
396 if (inputs.size() != 1)
399 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
403 typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
404 typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
409 patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override