25#include "llvm/ADT/STLExtras.h"
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
44 void populateConvertToEmitCConversionPatterns(
45 ConversionTarget &
target, TypeConverter &typeConverter,
46 RewritePatternSet &patterns)
const final {
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
64struct ConvertAlloca final :
public OpConversionPattern<memref::AllocaOp> {
65 using OpConversionPattern::OpConversionPattern;
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
69 ConversionPatternRewriter &rewriter)
const override {
71 if (!op.getType().hasStaticShape()) {
72 return rewriter.notifyMatchFailure(
73 op.getLoc(),
"cannot transform alloca with dynamic shape");
76 if (op.getAlignment().value_or(1) > 1) {
79 return rewriter.notifyMatchFailure(
80 op.getLoc(),
"cannot transform alloca with alignment requirement");
83 auto resultTy = getTypeConverter()->convertType(op.getType());
85 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
87 auto noInit = emitc::OpaqueAttr::get(
getContext(),
"");
88 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
95 if (opTy.getRank() == 0) {
98 resultTy = typeConverter->convertType(opTy);
103static Value calculateMemrefTotalSizeBytes(
Location loc, MemRefType memrefType,
106 "incompatible memref type for EmitC conversion");
107 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
108 builder, loc, emitc::SizeTType::get(builder.
getContext()),
111 {TypeAttr::get(memrefType.getElementType())}));
114 int64_t numElements = llvm::product_of(memrefType.getShape());
115 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
116 builder, loc, indexType, builder.
getIndexAttr(numElements));
119 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
120 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
122 return totalSizeBytes.getResult();
125static emitc::AddressOfOp
129 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
132 emitc::ArrayType arrayType = arrayValue.getType();
134 emitc::SubscriptOp subPtr =
136 emitc::AddressOfOp
ptr = emitc::AddressOfOp::create(
137 builder, loc, emitc::PointerType::get(arrayType.getElementType()),
143struct ConvertAlloc final :
public OpConversionPattern<memref::AllocOp> {
144 using OpConversionPattern::OpConversionPattern;
146 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
147 ConversionPatternRewriter &rewriter)
const override {
148 Location loc = allocOp.getLoc();
149 MemRefType memrefType = allocOp.getType();
151 return rewriter.notifyMatchFailure(
152 loc,
"incompatible memref type for EmitC conversion");
155 Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
156 Type elementType = memrefType.getElementType();
157 IndexType indexType = rewriter.getIndexType();
158 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
159 rewriter, loc, sizeTType, rewriter.getStringAttr(
"sizeof"),
161 ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
163 int64_t numElements = 1;
164 for (int64_t dimSize : memrefType.getShape()) {
165 numElements *= dimSize;
167 Value numElementsValue = emitc::ConstantOp::create(
168 rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
170 Value totalSizeBytes =
171 emitc::MulOp::create(rewriter, loc, sizeTType,
172 sizeofElementOp.getResult(0), numElementsValue);
174 emitc::CallOpaqueOp allocCall;
175 StringAttr allocFunctionName;
176 Value alignmentValue;
177 SmallVector<Value, 2> argsVec;
178 if (allocOp.getAlignment()) {
180 alignmentValue = emitc::ConstantOp::create(
181 rewriter, loc, sizeTType,
182 rewriter.getIntegerAttr(indexType,
183 allocOp.getAlignment().value_or(0)));
184 argsVec.push_back(alignmentValue);
189 argsVec.push_back(totalSizeBytes);
192 allocCall = emitc::CallOpaqueOp::create(
194 emitc::PointerType::get(
195 emitc::OpaqueType::get(rewriter.getContext(),
"void")),
196 allocFunctionName, args);
198 emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
199 emitc::CastOp castOp = emitc::CastOp::create(
200 rewriter, loc, targetPointerType, allocCall.getResult(0));
202 rewriter.replaceOp(allocOp, castOp);
207struct ConvertCopy final :
public OpConversionPattern<memref::CopyOp> {
208 using OpConversionPattern::OpConversionPattern;
211 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
212 ConversionPatternRewriter &rewriter)
const override {
213 Location loc = copyOp.getLoc();
214 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
215 MemRefType targetMemrefType =
216 cast<MemRefType>(copyOp.getTarget().getType());
219 return rewriter.notifyMatchFailure(
220 loc,
"incompatible source memref type for EmitC conversion");
223 return rewriter.notifyMatchFailure(
224 loc,
"incompatible target memref type for EmitC conversion");
227 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
228 emitc::AddressOfOp srcPtr =
229 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
231 auto targetArrayValue =
232 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
233 emitc::AddressOfOp targetPtr =
234 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
236 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
239 targetPtr.getResult(), srcPtr.getResult(),
240 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
242 rewriter.replaceOp(copyOp, memCpyCall.getResults());
248struct ConvertGlobal final :
public OpConversionPattern<memref::GlobalOp> {
249 using OpConversionPattern::OpConversionPattern;
252 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
253 ConversionPatternRewriter &rewriter)
const override {
254 MemRefType opTy = op.getType();
255 if (!op.getType().hasStaticShape()) {
256 return rewriter.notifyMatchFailure(
257 op.getLoc(),
"cannot transform global with dynamic shape");
260 if (op.getAlignment().value_or(1) > 1) {
262 return rewriter.notifyMatchFailure(
263 op.getLoc(),
"global variable with alignment requirement is "
264 "currently not supported");
267 Type resultTy = convertMemRefType(opTy, getTypeConverter());
270 return rewriter.notifyMatchFailure(op.getLoc(),
271 "cannot convert result type");
275 if (visibility != SymbolTable::Visibility::Public &&
276 visibility != SymbolTable::Visibility::Private) {
277 return rewriter.notifyMatchFailure(
279 "only public and private visibility is currently supported");
283 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
284 bool externSpecifier = !staticSpecifier;
286 Attribute initialValue = operands.getInitialValueAttr();
287 if (opTy.getRank() == 0) {
289 if (std::optional<Attribute> initValueAttr = op.getInitialValue()) {
290 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(*initValueAttr)) {
291 initialValue = elementsAttr.getSplatValue<Attribute>();
295 if (isa_and_present<UnitAttr>(initialValue))
298 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
299 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
300 staticSpecifier, operands.getConstant());
305struct ConvertGetGlobal final
306 :
public OpConversionPattern<memref::GetGlobalOp> {
307 using OpConversionPattern::OpConversionPattern;
310 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
311 ConversionPatternRewriter &rewriter)
const override {
313 MemRefType opTy = op.getType();
314 Type resultTy = convertMemRefType(opTy, getTypeConverter());
317 return rewriter.notifyMatchFailure(op.getLoc(),
318 "cannot convert result type");
321 if (opTy.getRank() == 0) {
322 emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
323 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
324 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
325 emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
326 rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
330 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
331 operands.getNameAttr());
336struct ConvertLoad final :
public OpConversionPattern<memref::LoadOp> {
337 using OpConversionPattern::OpConversionPattern;
341 ConversionPatternRewriter &rewriter)
const override {
343 auto resultTy = getTypeConverter()->convertType(op.getType());
345 return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert type");
349 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
351 return rewriter.notifyMatchFailure(op.getLoc(),
"expected array type");
354 auto subscript = emitc::SubscriptOp::create(
355 rewriter, op.getLoc(), arrayValue, operands.getIndices());
357 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
362struct ConvertStore final :
public OpConversionPattern<memref::StoreOp> {
363 using OpConversionPattern::OpConversionPattern;
367 ConversionPatternRewriter &rewriter)
const override {
369 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
371 return rewriter.notifyMatchFailure(op.getLoc(),
"expected array type");
374 auto subscript = emitc::SubscriptOp::create(
375 rewriter, op.getLoc(), arrayValue, operands.getIndices());
376 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
377 operands.getValue());
384 typeConverter.addConversion(
385 [&](MemRefType memRefType) -> std::optional<Type> {
389 Type convertedElementType =
390 typeConverter.convertType(memRefType.getElementType());
391 if (!convertedElementType)
393 return emitc::ArrayType::get(memRefType.getShape(),
394 convertedElementType);
397 auto materializeAsUnrealizedCast = [](
OpBuilder &builder,
Type resultType,
400 if (inputs.size() != 1)
403 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
407 typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
408 typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
413 patterns.
add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
void registerConvertMemRefToEmitCInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override