25 #include "llvm/ADT/STLExtras.h"
32 return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
33 memRefType.getRank() != 0 &&
34 !llvm::is_contained(memRefType.getShape(), 0);
44 void populateConvertToEmitCConversionPatterns(
55 dialect->addInterfaces<MemRefToEmitCDialectInterface>();
68 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
71 if (!op.getType().hasStaticShape()) {
73 op.getLoc(),
"cannot transform alloca with dynamic shape");
76 if (op.getAlignment().value_or(1) > 1) {
80 op.getLoc(),
"cannot transform alloca with alignment requirement");
83 auto resultTy = getTypeConverter()->convertType(op.getType());
95 if (opTy.getRank() == 0) {
103 static Value calculateMemrefTotalSizeBytes(
Location loc, MemRefType memrefType,
106 "incompatible memref type for EmitC conversion");
107 emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
111 {TypeAttr::get(memrefType.getElementType())}));
114 int64_t numElements = llvm::product_of(memrefType.getShape());
115 emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
116 builder, loc, indexType, builder.
getIndexAttr(numElements));
119 emitc::MulOp totalSizeBytes = emitc::MulOp::create(
120 builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
122 return totalSizeBytes.getResult();
125 static emitc::ApplyOp
129 emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
132 emitc::ArrayType arrayType = arrayValue.getType();
134 emitc::SubscriptOp subPtr =
135 emitc::SubscriptOp::create(builder, loc, arrayValue,
ValueRange(indices));
136 emitc::ApplyOp ptr = emitc::ApplyOp::create(
146 matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
149 MemRefType memrefType = allocOp.getType();
152 loc,
"incompatible memref type for EmitC conversion");
156 Type elementType = memrefType.getElementType();
158 emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
163 int64_t numElements = 1;
164 for (int64_t dimSize : memrefType.getShape()) {
165 numElements *= dimSize;
167 Value numElementsValue = emitc::ConstantOp::create(
168 rewriter, loc, indexType, rewriter.
getIndexAttr(numElements));
170 Value totalSizeBytes =
171 emitc::MulOp::create(rewriter, loc, sizeTType,
172 sizeofElementOp.getResult(0), numElementsValue);
174 emitc::CallOpaqueOp allocCall;
175 StringAttr allocFunctionName;
176 Value alignmentValue;
178 if (allocOp.getAlignment()) {
180 alignmentValue = emitc::ConstantOp::create(
181 rewriter, loc, sizeTType,
183 allocOp.getAlignment().value_or(0)));
184 argsVec.push_back(alignmentValue);
189 argsVec.push_back(totalSizeBytes);
192 allocCall = emitc::CallOpaqueOp::create(
196 allocFunctionName, args);
199 emitc::CastOp castOp = emitc::CastOp::create(
200 rewriter, loc, targetPointerType, allocCall.getResult(0));
211 matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
214 MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
215 MemRefType targetMemrefType =
216 cast<MemRefType>(copyOp.getTarget().getType());
220 loc,
"incompatible source memref type for EmitC conversion");
224 loc,
"incompatible target memref type for EmitC conversion");
227 cast<TypedValue<emitc::ArrayType>>(operands.getSource());
228 emitc::ApplyOp srcPtr =
229 createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
231 auto targetArrayValue =
232 cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
233 emitc::ApplyOp targetPtr =
234 createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
236 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
239 targetPtr.getResult(), srcPtr.getResult(),
240 calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
242 rewriter.
replaceOp(copyOp, memCpyCall.getResults());
252 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
254 MemRefType opTy = op.getType();
255 if (!op.getType().hasStaticShape()) {
257 op.getLoc(),
"cannot transform global with dynamic shape");
260 if (op.getAlignment().value_or(1) > 1) {
263 op.getLoc(),
"global variable with alignment requirement is "
264 "currently not supported");
267 Type resultTy = convertMemRefType(opTy, getTypeConverter());
271 "cannot convert result type");
279 "only public and private visibility is currently supported");
284 bool externSpecifier = !staticSpecifier;
286 Attribute initialValue = operands.getInitialValueAttr();
287 if (opTy.getRank() == 0) {
288 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
289 initialValue = elementsAttr.getSplatValue<
Attribute>();
291 if (isa_and_present<UnitAttr>(initialValue))
295 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
296 staticSpecifier, operands.getConstant());
301 struct ConvertGetGlobal final
306 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
309 MemRefType opTy = op.getType();
310 Type resultTy = convertMemRefType(opTy, getTypeConverter());
314 "cannot convert result type");
317 if (opTy.getRank() == 0) {
319 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
320 rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
327 operands.getNameAttr());
336 matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
339 auto resultTy = getTypeConverter()->convertType(op.getType());
345 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
350 auto subscript = emitc::SubscriptOp::create(
351 rewriter, op.getLoc(), arrayValue, operands.getIndices());
362 matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
365 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
370 auto subscript = emitc::SubscriptOp::create(
371 rewriter, op.getLoc(), arrayValue, operands.getIndices());
373 operands.getValue());
381 [&](MemRefType memRefType) -> std::optional<Type> {
385 Type convertedElementType =
386 typeConverter.
convertType(memRefType.getElementType());
387 if (!convertedElementType)
390 convertedElementType);
393 auto materializeAsUnrealizedCast = [](
OpBuilder &builder,
Type resultType,
396 if (inputs.size() != 1)
399 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
409 patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType)
constexpr const char * mallocFunctionName
constexpr const char * alignedAllocFunctionName
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
Visibility
An enumeration detailing the different visibility types that a symbol may have.
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConvertMemRefToEmitCInterface(DialectRegistry ®istry)