28#include "llvm/Support/LogicalResult.h"
31#define DEBUG_TYPE "wasm-convert"
34#define GEN_PASS_DEF_RAISEWASMMLIR
35#include "mlir/Conversion/Passes.h.inc"
// Pattern template that lowers a Wasm binary op to one of two target ops,
// chosen from the type of the op's right-hand operand: TargetIntOp for the
// integer case, TargetFPOp for the float case. Both variants keep the source
// op's result types and use the converted (adaptor) operands.
// NOTE(review): the dispatch-condition and return lines are not visible in
// this listing; presumably they test Type::isInteger()/isFloat() — confirm.
42template <
typename SourceOp,
typename TargetIntOp,
typename TargetFPOp>
43struct IntFPDispatchMappingConversion : OpConversionPattern<SourceOp> {
44 using OpConversionPattern<SourceOp>::OpConversionPattern;
47 matchAndRewrite(SourceOp srcOp,
typename SourceOp::Adaptor adaptor,
48 ConversionPatternRewriter &rewriter)
const override {
// Only the RHS operand's type is consulted to pick the lowering.
49 Type type = srcOp.getRhs().getType();
// Integer lowering.
51 rewriter.replaceOpWithNewOp<TargetIntOp>(srcOp, srcOp->getResultTypes(),
52 adaptor.getOperands());
// Floating-point lowering.
57 rewriter.replaceOpWithNewOp<TargetFPOp>(srcOp, srcOp->getResultTypes(),
58 adaptor.getOperands());
63using WasmAddOpConversion =
64 IntFPDispatchMappingConversion<AddOp, arith::AddIOp, arith::AddFOp>;
65using WasmMulOpConversion =
66 IntFPDispatchMappingConversion<MulOp, arith::MulIOp, arith::MulFOp>;
67using WasmSubOpConversion =
68 IntFPDispatchMappingConversion<SubOp, arith::SubIOp, arith::SubFOp>;
// Pattern template for Wasm ops that map one-to-one onto a single target op:
// the source op is replaced by TargetOp with the same result types and the
// converted (adaptor) operands.
72template <
typename SourceOp,
typename TargetOp>
73struct OpMappingConversion : OpConversionPattern<SourceOp> {
74 using OpConversionPattern<SourceOp>::OpConversionPattern;
77 matchAndRewrite(SourceOp srcOp,
typename SourceOp::Adaptor adaptor,
78 ConversionPatternRewriter &rewriter)
const override {
79 rewriter.replaceOpWithNewOp<TargetOp>(srcOp, srcOp->getResultTypes(),
80 adaptor.getOperands());
85using WasmAndOpConversion = OpMappingConversion<AndOp, arith::AndIOp>;
86using WasmCeilOpConversion = OpMappingConversion<CeilOp, math::CeilOp>;
89using WasmConvertSOpConversion =
90 OpMappingConversion<ConvertSOp, arith::SIToFPOp>;
91using WasmConvertUOpConversion =
92 OpMappingConversion<ConvertUOp, arith::UIToFPOp>;
93using WasmDemoteOpConversion = OpMappingConversion<DemoteOp, arith::TruncFOp>;
94using WasmDivFPOpConversion = OpMappingConversion<DivOp, arith::DivFOp>;
95using WasmDivSIOpConversion = OpMappingConversion<DivSIOp, arith::DivSIOp>;
96using WasmDivUIOpConversion = OpMappingConversion<DivUIOp, arith::DivUIOp>;
97using WasmExtendSOpConversion =
98 OpMappingConversion<ExtendSI32Op, arith::ExtSIOp>;
99using WasmExtendUOpConversion =
100 OpMappingConversion<ExtendUI32Op, arith::ExtUIOp>;
101using WasmFloorOpConversion = OpMappingConversion<FloorOp, math::FloorOp>;
102using WasmMaxOpConversion = OpMappingConversion<MaxOp, arith::MaximumFOp>;
103using WasmMinOpConversion = OpMappingConversion<MinOp, arith::MinimumFOp>;
104using WasmOrOpConversion = OpMappingConversion<OrOp, arith::OrIOp>;
105using WasmPromoteOpConversion = OpMappingConversion<PromoteOp, arith::ExtFOp>;
106using WasmRemSIOpConversion = OpMappingConversion<RemSIOp, arith::RemSIOp>;
107using WasmRemUIOpConversion = OpMappingConversion<RemUIOp, arith::RemUIOp>;
108using WasmReinterpretOpConversion =
109 OpMappingConversion<ReinterpretOp, arith::BitcastOp>;
110using WasmShLOpConversion = OpMappingConversion<ShLOp, arith::ShLIOp>;
111using WasmShRSOpConversion = OpMappingConversion<ShRSOp, arith::ShRSIOp>;
112using WasmShRUOpConversion = OpMappingConversion<ShRUOp, arith::ShRUIOp>;
113using WasmXOrOpConversion = OpMappingConversion<XOrOp, arith::XOrIOp>;
114using WasmNegOpConversion = OpMappingConversion<NegOp, arith::NegFOp>;
115using WasmCopySignOpConversion =
116 OpMappingConversion<CopySignOp, math::CopySignOp>;
117using WasmClzOpConversion =
118 OpMappingConversion<ClzOp, math::CountLeadingZerosOp>;
119using WasmCtzOpConversion =
120 OpMappingConversion<CtzOp, math::CountTrailingZerosOp>;
121using WasmPopCntOpConversion = OpMappingConversion<PopCntOp, math::CtPopOp>;
122using WasmAbsOpConversion = OpMappingConversion<AbsOp, math::AbsFOp>;
123using WasmTruncOpConversion = OpMappingConversion<TruncOp, math::TruncOp>;
124using WasmSqrtOpConversion = OpMappingConversion<SqrtOp, math::SqrtOp>;
125using WasmWrapOpConversion = OpMappingConversion<WrapOp, arith::TruncIOp>;
127struct WasmCallOpConversion : OpConversionPattern<FuncCallOp> {
128 using OpConversionPattern::OpConversionPattern;
131 matchAndRewrite(FuncCallOp funcCallOp, FuncCallOp::Adaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
133 rewriter.replaceOpWithNewOp<func::CallOp>(
134 funcCallOp, funcCallOp.getCallee(), funcCallOp.getResults().getTypes(),
135 funcCallOp.getOperands());
// Lowers a Wasm ConstOp to an arith.constant built from constOp.getValue().
140struct WasmConstOpConversion : OpConversionPattern<ConstOp> {
141 using OpConversionPattern::OpConversionPattern;
144 matchAndRewrite(ConstOp constOp, ConstOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter)
const override {
146 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, constOp.getValue());
// Lowers a Wasm function import to a func.func with the import's symbol name
// and function type, marked private.
// NOTE(review): no body is populated here, so the result is presumably an
// external declaration — confirm. Import symbols are later renamed to their
// qualified import names by runOnOperation().
151struct WasmFuncImportOpConversion : OpConversionPattern<FuncImportOp> {
152 using OpConversionPattern::OpConversionPattern;
155 matchAndRewrite(FuncImportOp funcImportOp, FuncImportOp::Adaptor,
156 ConversionPatternRewriter &rewriter)
const override {
157 auto nFunc = rewriter.replaceOpWithNewOp<func::FuncOp>(
158 funcImportOp, funcImportOp.getSymName(), funcImportOp.getType());
// Give the declaration private symbol visibility.
159 nFunc.setVisibility(SymbolTable::Visibility::Private);
// Lowers a Wasm function to func.func.
// The steps are:
//   1. Create a new function with the same symbol name and function type.
//   2. Clone the original body region into it.
//   3. Rewrite each entry-block argument from LocalRefType to that type's
//      element type, via a signature conversion.
//   4. Replace the original op with the new function.
164struct WasmFuncOpConversion : OpConversionPattern<FuncOp> {
165 using OpConversionPattern::OpConversionPattern;
168 matchAndRewrite(FuncOp funcOp, FuncOp::Adaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
// Replacement function (presumably bound to `newFunc` on a line not visible
// in this listing).
171 func::FuncOp::create(rewriter, funcOp->getLoc(), funcOp.getSymName(),
172 funcOp.getFunctionType());
// Copy the original body into the end of the new function's region.
173 rewriter.cloneRegionBefore(funcOp.getBody(), newFunc.getBody(),
174 newFunc.getBody().end());
175 Block *oldEntryBlock = &newFunc.getBody().front();
// One signature-conversion slot per entry-block argument.
177 TypeConverter::SignatureConversion sC{oldEntryBlock->
getNumArguments()};
178 auto numArgs = blockArgTypes.size();
179 for (
size_t i = 0; i < numArgs; ++i) {
180 auto argType = dyn_cast<LocalRefType>(blockArgTypes[i]);
// NOTE(review): the lines between the dyn_cast and here are missing from
// this listing; presumably a null argType is rejected before this
// dereference — confirm.
// Argument i is remapped to the element type of its LocalRefType.
183 sC.addInputs(i, argType.getElementType());
186 rewriter.applySignatureConversion(oldEntryBlock, sC, getTypeConverter());
187 rewriter.replaceOp(funcOp, newFunc);
// Lowers a Wasm global import to a memref.global of shape [1] with the
// imported element type and symbol visibility "nested". A null Attribute is
// passed in what is presumably the initial-value slot. The global is
// constant exactly when the import is not mutable.
192struct WasmGlobalImportOpConverter : OpConversionPattern<GlobalImportOp> {
193 using OpConversionPattern::OpConversionPattern;
195 matchAndRewrite(GlobalImportOp gIOp, GlobalImportOp::Adaptor adaptor,
196 ConversionPatternRewriter &rewriter)
const override {
197 auto memrefGOp = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
198 gIOp, gIOp.getSymNameAttr(), rewriter.getStringAttr(
"nested"),
199 TypeAttr::get(MemRefType::get({1}, gIOp.getType())), Attribute{},
// Immutable Wasm globals become constant memref globals.
202 memrefGOp.setConstant(!gIOp.getIsMutable());
// CRTP base for lowering a Wasm GlobalOp.
// It matches when the global's initializer terminator (a ReturnOp) returns
// exactly one value, and that value is defined by an OriginOpType op. On a
// match, the derived class's handleInitializer() performs the rewrite.
207template <
typename CRTP,
typename OriginOpType>
208struct GlobalOpConverter : OpConversionPattern<GlobalOp> {
209 using OpConversionPattern::OpConversionPattern;
211 matchAndRewrite(GlobalOp globalOp, GlobalOp::Adaptor adaptor,
212 ConversionPatternRewriter &rewriter)
const override {
213 ReturnOp rop = globalOp.getInitTerminator();
215 if (rop->getNumOperands() != 1)
216 return rewriter.notifyMatchFailure(
217 globalOp,
"globalOp initializer should return one value exactly");
// NOTE(review): getDefiningOp() returns null when the operand is a block
// argument, and llvm::dyn_cast asserts on a null pointer.
// `rop->getOperand(0).getDefiningOp<OriginOpType>()` (or dyn_cast_or_null)
// would be the null-safe form. Confirm whether a block argument can reach
// this point.
220 dyn_cast<OriginOpType>(rop->getOperand(0).getDefiningOp());
// Initializer produced by a different op kind; another GlobalOpConverter
// instantiation may match instead.
223 return rewriter.notifyMatchFailure(
224 globalOp,
"invalid initializer op type for this pattern");
// Static (CRTP) dispatch to the derived converter.
226 return static_cast<CRTP
const *
>(
this)->handleInitializer(
227 globalOp, rewriter, initializerOp);
// Lowers a Wasm global initialized by a ConstOp to a private memref.global of
// shape [1]. Its initial value is built from the constant's value attribute.
// The global is constant exactly when the Wasm global is not mutable.
231struct WasmGlobalWithConstInitConversion
232 : GlobalOpConverter<WasmGlobalWithConstInitConversion, ConstOp> {
233 using GlobalOpConverter::GlobalOpConverter;
// Invoked by GlobalOpConverter::matchAndRewrite once the initializer has been
// matched as a ConstOp.
234 LogicalResult handleInitializer(GlobalOp globalOp,
235 ConversionPatternRewriter &rewriter,
236 ConstOp constInit)
const {
// NOTE(review): the start of this statement is not visible; presumably it
// builds the DenseElementsAttr `initializer` used below — confirm.
239 ArrayRef<Attribute>{constInit.getValueAttr()});
240 auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
241 globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr(
"private"),
242 TypeAttr::get(MemRefType::get({1}, globalOp.getType())), initializer,
// Immutable Wasm globals become constant memref globals.
245 globalReplacement.setConstant(!globalOp.getIsMutable());
// Lowers a Wasm global whose initializer reads another global (GlobalGetOp).
//   - The global becomes a private memref.global of shape [1]. A unit
//     attribute is passed in what is presumably the initial-value slot.
//   - A separate func.func named "<sym>::initializer" is emitted and tagged
//     with an "initializer" unit attribute.
//   - That function loads the single element of the source global and stores
//     it into the new global.
// NOTE(review): who invokes the generated initializer function is not shown
// here.
250struct WasmGlobalWithGetGlobalInitConversion
251 : GlobalOpConverter<WasmGlobalWithGetGlobalInitConversion, GlobalGetOp> {
252 using GlobalOpConverter::GlobalOpConverter;
// NOTE(review): despite its name, `constInit` is the GlobalGetOp initializer.
253 LogicalResult handleInitializer(GlobalOp globalOp,
254 ConversionPatternRewriter &rewriter,
255 GlobalGetOp constInit)
const {
256 auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
257 globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr(
"private"),
258 TypeAttr::get(MemRefType::get({1}, globalOp.getType())),
259 rewriter.getUnitAttr(),
// Immutable Wasm globals become constant memref globals.
262 globalReplacement.setConstant(!globalOp.getIsMutable());
263 auto loc = globalOp.getLoc();
// Initializer function name: "<global symbol>::initializer".
264 auto initializerName = (globalOp.getSymName() +
"::initializer").str();
265 auto globalInitializer =
266 func::FuncOp::create(rewriter, loc, initializerName,
// Tag the function as a global initializer.
268 globalInitializer->setAttr(rewriter.getStringAttr(
"initializer"),
269 rewriter.getUnitAttr());
270 auto *initializerBody = globalInitializer.addEntryBlock();
// Build the initializer body, then restore the caller's insertion point.
271 auto sip = rewriter.saveInsertionPoint();
272 rewriter.setInsertionPointToStart(initializerBody);
// Handle to the source global named by the GlobalGetOp.
273 auto srcGlobalPtr = memref::GetGlobalOp::create(
274 rewriter, loc, MemRefType::get({1}, constInit.getType()),
275 constInit.getGlobal());
// Handle to the newly created destination global.
277 memref::GetGlobalOp::create(rewriter, loc, globalReplacement.getType(),
278 globalReplacement.getSymName());
// Copy the single element from source to destination; `idx` is presumably a
// constant index 0 defined on a line not visible here.
281 memref::LoadOp::create(rewriter, loc, srcGlobalPtr,
ValueRange{idx});
282 memref::StoreOp::create(rewriter, loc, loadSrc.getResult(),
284 func::ReturnOp::create(rewriter, loc);
285 rewriter.restoreInsertionPoint(sip);
// Returns a zero-valued attribute of type `t`: IntegerAttr 0 or FloatAttr
// 0.0. Only integer and float types are supported, per the assertion message.
// Used by WasmLocalConversion to zero-initialize locals.
290inline TypedAttr getInitializerAttr(
Type t) {
292 "This helper is intended to use with int and float types");
294 return IntegerAttr::get(t, 0);
296 return FloatAttr::get(t, 0.);
// Lowers a Wasm LocalOp to a memref.alloca. The allocation is then
// zero-initialized by storing getInitializerAttr(<element type>) into it.
300struct WasmLocalConversion : OpConversionPattern<LocalOp> {
301 using OpConversionPattern::OpConversionPattern;
303 matchAndRewrite(LocalOp localOp, LocalOp::Adaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const override {
305 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
// Zero constant of the local's element type.
308 auto initializer = arith::ConstantOp::create(
309 rewriter, localOp->getLoc(),
310 getInitializerAttr(localOp.getResult().getType().getElementType()));
// No indices are passed, consistent with the rank-0 memref produced by this
// file's LocalRefType type conversion.
311 memref::StoreOp::create(rewriter, localOp->getLoc(),
312 initializer.getResult(), alloca.getResult());
// Lowers a Wasm LocalGetOp to a memref.load of the local. The result type is
// kept the same as the original op's.
317struct WasmLocalGetConversion : OpConversionPattern<LocalGetOp> {
318 using OpConversionPattern::OpConversionPattern;
320 matchAndRewrite(LocalGetOp localGetOp, LocalGetOp::Adaptor adaptor,
321 ConversionPatternRewriter &rewriter)
const override {
// adaptor.getLocalVar() is presumably the memref produced for the local by
// the LocalRefType conversion — confirm.
322 rewriter.replaceOpWithNewOp<memref::LoadOp>(
323 localGetOp, localGetOp.getResult().
getType(), adaptor.getLocalVar(),
// Lowers a Wasm LocalSetOp to a memref.store of the converted value into the
// converted local, with an empty index list.
329struct WasmLocalSetConversion : OpConversionPattern<LocalSetOp> {
330 using OpConversionPattern::OpConversionPattern;
332 matchAndRewrite(LocalSetOp localSetOp, LocalSetOp::Adaptor adaptor,
333 ConversionPatternRewriter &rewriter)
const override {
334 rewriter.replaceOpWithNewOp<memref::StoreOp>(
335 localSetOp, adaptor.getValue(), adaptor.getLocalVar(),
ValueRange{});
// Lowers a Wasm LocalTeeOp: stores the value into the converted local, then
// forwards that same value as the op's result.
340struct WasmLocalTeeConversion : OpConversionPattern<LocalTeeOp> {
341 using OpConversionPattern::OpConversionPattern;
343 matchAndRewrite(LocalTeeOp localTeeOp, LocalTeeOp::Adaptor adaptor,
344 ConversionPatternRewriter &rewriter)
const override {
345 memref::StoreOp::create(rewriter, localTeeOp->getLoc(), adaptor.getValue(),
346 adaptor.getLocalVar());
// Uses of the tee result now refer directly to the stored value.
347 rewriter.replaceOp(localTeeOp, adaptor.getValue());
// Lowers a Wasm ReturnOp to func.return with the converted operands.
352struct WasmReturnOpConversion : OpConversionPattern<ReturnOp> {
353 using OpConversionPattern::OpConversionPattern;
356 matchAndRewrite(ReturnOp returnOp, ReturnOp::Adaptor adaptor,
357 ConversionPatternRewriter &rewriter)
const override {
358 rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
359 adaptor.getOperands());
// Pass driver. It:
//   - marks the WasmSSA dialect illegal and arith/builtin/cf/func/memref/math
//     legal;
//   - configures a type converter (identity, plus LocalRefType -> rank-0
//     memref);
//   - records each import's qualified name;
//   - runs a full conversion;
//   - finally renames each recorded import symbol to its qualified name.
// Any failure signals pass failure.
365 void runOnOperation()
override {
367 target.addIllegalDialect<WasmSSADialect>();
368 target.addLegalDialect<arith::ArithDialect, BuiltinDialect,
369 cf::ControlFlowDialect, func::FuncDialect,
370 memref::MemRefDialect, math::MathDialect>();
// Identity conversion for all types.
373 tc.addConversion([](Type type) -> std::optional<Type> {
return type; });
// Wasm local references become rank-0 memrefs of their element type.
374 tc.addConversion([](LocalRefType type) -> std::optional<Type> {
375 return MemRefType::get({}, type.getElementType());
// Target materialization: a single value whose type is the memref's element
// type is materialized by allocating with memref.alloca and storing the value
// into it. Other inputs are rejected on a line not visible here.
377 tc.addTargetMaterialization([](OpBuilder &builder, MemRefType destType,
379 if (values.size() != 1 ||
380 values.front().
getType() != destType.getElementType())
382 auto localVar = memref::AllocaOp::create(builder, loc, destType);
383 memref::StoreOp::create(builder, loc, values.front(),
384 localVar.getResult());
385 return localVar.getResult();
// Record symbol name -> qualified import name for every import op. This runs
// before the conversion below rewrites those ops.
389 llvm::DenseMap<StringAttr, StringAttr> idxSymToImportSym{};
390 auto *topOp = getOperation();
391 topOp->walk([&idxSymToImportSym,
this](ImportOpInterface importOp) {
392 auto const qualifiedImportName = importOp.getQualifiedImportName();
393 auto qualNameAttr = StringAttr::get(&
getContext(), qualifiedImportName);
394 idxSymToImportSym.insert(
395 std::make_pair(importOp.getSymbolName(), qualNameAttr));
398 if (
failed(applyFullConversion(topOp,
target, std::move(patterns))))
399 return signalPassFailure();
// Rename each recorded symbol to its qualified import name; any failed
// rename aborts the pass.
401 auto symTable = SymbolTable{topOp};
402 for (
auto &[oldName, newName] : idxSymToImportSym) {
403 if (
failed(symTable.rename(oldName, newName)))
404 return signalPassFailure();
// NOTE(review): this is the tail of a pattern-registration list. It is
// presumably the RewritePatternSet::add<...> call in
// populateRaiseWasmMLIRConversionPatterns. Several entries and the function
// header are not visible in this listing.
420 WasmCallOpConversion,
421 WasmCeilOpConversion,
423 WasmConstOpConversion,
424 WasmConvertSOpConversion,
425 WasmConvertUOpConversion,
426 WasmCopySignOpConversion,
428 WasmDemoteOpConversion,
429 WasmDivFPOpConversion,
430 WasmDivSIOpConversion,
431 WasmDivUIOpConversion,
432 WasmExtendSOpConversion,
433 WasmExtendUOpConversion,
434 WasmFloorOpConversion,
435 WasmFuncImportOpConversion,
436 WasmFuncOpConversion,
437 WasmGlobalImportOpConverter,
438 WasmGlobalWithConstInitConversion,
439 WasmGlobalWithGetGlobalInitConversion,
441 WasmLocalGetConversion,
442 WasmLocalSetConversion,
443 WasmLocalTeeConversion,
449 WasmPopCntOpConversion,
450 WasmPromoteOpConversion,
451 WasmReinterpretOpConversion,
452 WasmRemSIOpConversion,
453 WasmRemUIOpConversion,
454 WasmReturnOpConversion,
456 WasmShRSOpConversion,
457 WasmShRUOpConversion,
458 WasmSqrtOpConversion,
460 WasmTruncOpConversion,
461 WasmWrapOpConversion,
// Pass factory; presumably the body of createRaiseWasmMLIRPass().
468 return std::make_unique<RaiseWasmMLIRPass>();
static Type getElementType(Type type)
Determine the element type of type.
std::unique_ptr< Pass > createRaiseWasmMLIRPass()
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isFloat() const
Return true if this is an float type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
type_range getType() const
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
void populateRaiseWasmMLIRConversionPatterns(TypeConverter &, RewritePatternSet &)
Collect a set of patterns to convert from the Wasm dialect to standard dialects.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.