MLIR 23.0.0git
RaiseWasmMLIR.cpp
Go to the documentation of this file.
1//===- RaiseWasmMLIR.cpp - Convert Wasm to less abstract dialects ---*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements the lowering of WasmSSA operations to ops of standard dialects.
11//
12//===----------------------------------------------------------------------===//
13
15
25#include "mlir/IR/ValueRange.h"
28#include "llvm/Support/LogicalResult.h"
29#include <optional>
30
31#define DEBUG_TYPE "wasm-convert"
32
33namespace mlir {
34#define GEN_PASS_DEF_RAISEWASMMLIR
35#include "mlir/Conversion/Passes.h.inc"
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::wasmssa;
40namespace {
41
42template <typename SourceOp, typename TargetIntOp, typename TargetFPOp>
43struct IntFPDispatchMappingConversion : OpConversionPattern<SourceOp> {
44 using OpConversionPattern<SourceOp>::OpConversionPattern;
45
46 LogicalResult
47 matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
48 ConversionPatternRewriter &rewriter) const override {
49 Type type = srcOp.getRhs().getType();
50 if (type.isInteger()) {
51 rewriter.replaceOpWithNewOp<TargetIntOp>(srcOp, srcOp->getResultTypes(),
52 adaptor.getOperands());
53 return success();
54 }
55 if (!type.isFloat())
56 return failure();
57 rewriter.replaceOpWithNewOp<TargetFPOp>(srcOp, srcOp->getResultTypes(),
58 adaptor.getOperands());
59 return success();
60 }
61};
62
63using WasmAddOpConversion =
64 IntFPDispatchMappingConversion<AddOp, arith::AddIOp, arith::AddFOp>;
65using WasmMulOpConversion =
66 IntFPDispatchMappingConversion<MulOp, arith::MulIOp, arith::MulFOp>;
67using WasmSubOpConversion =
68 IntFPDispatchMappingConversion<SubOp, arith::SubIOp, arith::SubFOp>;
69
70/// Convert a k-ary source operation \p SourceOp into an operation \p TargetOp.
71/// Both \p SourceOp and \p TargetOp must have the same number of operands.
72template <typename SourceOp, typename TargetOp>
73struct OpMappingConversion : OpConversionPattern<SourceOp> {
74 using OpConversionPattern<SourceOp>::OpConversionPattern;
75
76 LogicalResult
77 matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
78 ConversionPatternRewriter &rewriter) const override {
79 rewriter.replaceOpWithNewOp<TargetOp>(srcOp, srcOp->getResultTypes(),
80 adaptor.getOperands());
81 return success();
82 }
83};
84
85using WasmAndOpConversion = OpMappingConversion<AndOp, arith::AndIOp>;
86using WasmCeilOpConversion = OpMappingConversion<CeilOp, math::CeilOp>;
87/// TODO: SIToFP and UIToFP don't allow specification of the floating point
88/// rounding mode
89using WasmConvertSOpConversion =
90 OpMappingConversion<ConvertSOp, arith::SIToFPOp>;
91using WasmConvertUOpConversion =
92 OpMappingConversion<ConvertUOp, arith::UIToFPOp>;
93using WasmDemoteOpConversion = OpMappingConversion<DemoteOp, arith::TruncFOp>;
94using WasmDivFPOpConversion = OpMappingConversion<DivOp, arith::DivFOp>;
95using WasmDivSIOpConversion = OpMappingConversion<DivSIOp, arith::DivSIOp>;
96using WasmDivUIOpConversion = OpMappingConversion<DivUIOp, arith::DivUIOp>;
97using WasmExtendSOpConversion =
98 OpMappingConversion<ExtendSI32Op, arith::ExtSIOp>;
99using WasmExtendUOpConversion =
100 OpMappingConversion<ExtendUI32Op, arith::ExtUIOp>;
101using WasmFloorOpConversion = OpMappingConversion<FloorOp, math::FloorOp>;
102using WasmMaxOpConversion = OpMappingConversion<MaxOp, arith::MaximumFOp>;
103using WasmMinOpConversion = OpMappingConversion<MinOp, arith::MinimumFOp>;
104using WasmOrOpConversion = OpMappingConversion<OrOp, arith::OrIOp>;
105using WasmPromoteOpConversion = OpMappingConversion<PromoteOp, arith::ExtFOp>;
106using WasmRemSIOpConversion = OpMappingConversion<RemSIOp, arith::RemSIOp>;
107using WasmRemUIOpConversion = OpMappingConversion<RemUIOp, arith::RemUIOp>;
108using WasmReinterpretOpConversion =
109 OpMappingConversion<ReinterpretOp, arith::BitcastOp>;
110using WasmShLOpConversion = OpMappingConversion<ShLOp, arith::ShLIOp>;
111using WasmShRSOpConversion = OpMappingConversion<ShRSOp, arith::ShRSIOp>;
112using WasmShRUOpConversion = OpMappingConversion<ShRUOp, arith::ShRUIOp>;
113using WasmXOrOpConversion = OpMappingConversion<XOrOp, arith::XOrIOp>;
114using WasmNegOpConversion = OpMappingConversion<NegOp, arith::NegFOp>;
115using WasmCopySignOpConversion =
116 OpMappingConversion<CopySignOp, math::CopySignOp>;
117using WasmClzOpConversion =
118 OpMappingConversion<ClzOp, math::CountLeadingZerosOp>;
119using WasmCtzOpConversion =
120 OpMappingConversion<CtzOp, math::CountTrailingZerosOp>;
121using WasmPopCntOpConversion = OpMappingConversion<PopCntOp, math::CtPopOp>;
122using WasmAbsOpConversion = OpMappingConversion<AbsOp, math::AbsFOp>;
123using WasmTruncOpConversion = OpMappingConversion<TruncOp, math::TruncOp>;
124using WasmSqrtOpConversion = OpMappingConversion<SqrtOp, math::SqrtOp>;
125using WasmWrapOpConversion = OpMappingConversion<WrapOp, arith::TruncIOp>;
126
127struct WasmCallOpConversion : OpConversionPattern<FuncCallOp> {
128 using OpConversionPattern::OpConversionPattern;
129
130 LogicalResult
131 matchAndRewrite(FuncCallOp funcCallOp, FuncCallOp::Adaptor adaptor,
132 ConversionPatternRewriter &rewriter) const override {
133 rewriter.replaceOpWithNewOp<func::CallOp>(
134 funcCallOp, funcCallOp.getCallee(), funcCallOp.getResults().getTypes(),
135 funcCallOp.getOperands());
136 return success();
137 }
138};
139
140struct WasmConstOpConversion : OpConversionPattern<ConstOp> {
141 using OpConversionPattern::OpConversionPattern;
142
143 LogicalResult
144 matchAndRewrite(ConstOp constOp, ConstOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, constOp.getValue());
147 return success();
148 }
149};
150
151struct WasmFuncImportOpConversion : OpConversionPattern<FuncImportOp> {
152 using OpConversionPattern::OpConversionPattern;
153
154 LogicalResult
155 matchAndRewrite(FuncImportOp funcImportOp, FuncImportOp::Adaptor,
156 ConversionPatternRewriter &rewriter) const override {
157 auto nFunc = rewriter.replaceOpWithNewOp<func::FuncOp>(
158 funcImportOp, funcImportOp.getSymName(), funcImportOp.getType());
159 nFunc.setVisibility(SymbolTable::Visibility::Private);
160 return success();
161 }
162};
163
164struct WasmFuncOpConversion : OpConversionPattern<FuncOp> {
165 using OpConversionPattern::OpConversionPattern;
166
167 LogicalResult
168 matchAndRewrite(FuncOp funcOp, FuncOp::Adaptor adaptor,
169 ConversionPatternRewriter &rewriter) const override {
170 auto newFunc =
171 func::FuncOp::create(rewriter, funcOp->getLoc(), funcOp.getSymName(),
172 funcOp.getFunctionType());
173 rewriter.cloneRegionBefore(funcOp.getBody(), newFunc.getBody(),
174 newFunc.getBody().end());
175 Block *oldEntryBlock = &newFunc.getBody().front();
176 auto blockArgTypes = oldEntryBlock->getArgumentTypes();
177 TypeConverter::SignatureConversion sC{oldEntryBlock->getNumArguments()};
178 auto numArgs = blockArgTypes.size();
179 for (size_t i = 0; i < numArgs; ++i) {
180 auto argType = dyn_cast<LocalRefType>(blockArgTypes[i]);
181 if (!argType)
182 return failure();
183 sC.addInputs(i, argType.getElementType());
184 }
185
186 rewriter.applySignatureConversion(oldEntryBlock, sC, getTypeConverter());
187 rewriter.replaceOp(funcOp, newFunc);
188 return success();
189 }
190};
191
192struct WasmGlobalImportOpConverter : OpConversionPattern<GlobalImportOp> {
193 using OpConversionPattern::OpConversionPattern;
194 LogicalResult
195 matchAndRewrite(GlobalImportOp gIOp, GlobalImportOp::Adaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 auto memrefGOp = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
198 gIOp, gIOp.getSymNameAttr(), rewriter.getStringAttr("nested"),
199 TypeAttr::get(MemRefType::get({1}, gIOp.getType())), Attribute{},
200 /*constant*/ UnitAttr{},
201 /*alignment*/ IntegerAttr{});
202 memrefGOp.setConstant(!gIOp.getIsMutable());
203 return success();
204 }
205};
206
207template <typename CRTP, typename OriginOpType>
208struct GlobalOpConverter : OpConversionPattern<GlobalOp> {
209 using OpConversionPattern::OpConversionPattern;
210 LogicalResult
211 matchAndRewrite(GlobalOp globalOp, GlobalOp::Adaptor adaptor,
212 ConversionPatternRewriter &rewriter) const override {
213 ReturnOp rop = globalOp.getInitTerminator();
214
215 if (rop->getNumOperands() != 1)
216 return rewriter.notifyMatchFailure(
217 globalOp, "globalOp initializer should return one value exactly");
218
219 auto initializerOp =
220 dyn_cast<OriginOpType>(rop->getOperand(0).getDefiningOp());
221
222 if (!initializerOp)
223 return rewriter.notifyMatchFailure(
224 globalOp, "invalid initializer op type for this pattern");
225
226 return static_cast<CRTP const *>(this)->handleInitializer(
227 globalOp, rewriter, initializerOp);
228 }
229};
230
231struct WasmGlobalWithConstInitConversion
232 : GlobalOpConverter<WasmGlobalWithConstInitConversion, ConstOp> {
233 using GlobalOpConverter::GlobalOpConverter;
234 LogicalResult handleInitializer(GlobalOp globalOp,
235 ConversionPatternRewriter &rewriter,
236 ConstOp constInit) const {
237 auto initializer =
238 DenseElementsAttr::get(RankedTensorType::get({1}, globalOp.getType()),
239 ArrayRef<Attribute>{constInit.getValueAttr()});
240 auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
241 globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr("private"),
242 TypeAttr::get(MemRefType::get({1}, globalOp.getType())), initializer,
243 /*constant*/ UnitAttr{},
244 /*alignment*/ IntegerAttr{});
245 globalReplacement.setConstant(!globalOp.getIsMutable());
246 return success();
247 }
248};
249
250struct WasmGlobalWithGetGlobalInitConversion
251 : GlobalOpConverter<WasmGlobalWithGetGlobalInitConversion, GlobalGetOp> {
252 using GlobalOpConverter::GlobalOpConverter;
253 LogicalResult handleInitializer(GlobalOp globalOp,
254 ConversionPatternRewriter &rewriter,
255 GlobalGetOp constInit) const {
256 auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
257 globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr("private"),
258 TypeAttr::get(MemRefType::get({1}, globalOp.getType())),
259 rewriter.getUnitAttr(),
260 /*constant*/ UnitAttr{},
261 /*alignment*/ IntegerAttr{});
262 globalReplacement.setConstant(!globalOp.getIsMutable());
263 auto loc = globalOp.getLoc();
264 auto initializerName = (globalOp.getSymName() + "::initializer").str();
265 auto globalInitializer =
266 func::FuncOp::create(rewriter, loc, initializerName,
267 FunctionType::get(getContext(), {}, {}));
268 globalInitializer->setAttr(rewriter.getStringAttr("initializer"),
269 rewriter.getUnitAttr());
270 auto *initializerBody = globalInitializer.addEntryBlock();
271 auto sip = rewriter.saveInsertionPoint();
272 rewriter.setInsertionPointToStart(initializerBody);
273 auto srcGlobalPtr = memref::GetGlobalOp::create(
274 rewriter, loc, MemRefType::get({1}, constInit.getType()),
275 constInit.getGlobal());
276 auto destGlobalPtr =
277 memref::GetGlobalOp::create(rewriter, loc, globalReplacement.getType(),
278 globalReplacement.getSymName());
279 auto idx = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
280 auto loadSrc =
281 memref::LoadOp::create(rewriter, loc, srcGlobalPtr, ValueRange{idx});
282 memref::StoreOp::create(rewriter, loc, loadSrc.getResult(),
283 destGlobalPtr.getResult(), ValueRange{idx});
284 func::ReturnOp::create(rewriter, loc);
285 rewriter.restoreInsertionPoint(sip);
286 return success();
287 }
288};
289
290inline TypedAttr getInitializerAttr(Type t) {
291 assert(t.isIntOrFloat() &&
292 "This helper is intended to use with int and float types");
293 if (t.isInteger())
294 return IntegerAttr::get(t, 0);
295 if (t.isFloat())
296 return FloatAttr::get(t, 0.);
297 return TypedAttr{};
298}
299
300struct WasmLocalConversion : OpConversionPattern<LocalOp> {
301 using OpConversionPattern::OpConversionPattern;
302 LogicalResult
303 matchAndRewrite(LocalOp localOp, LocalOp::Adaptor adaptor,
304 ConversionPatternRewriter &rewriter) const override {
305 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
306 localOp,
307 MemRefType::get({}, localOp.getResult().getType().getElementType()));
308 auto initializer = arith::ConstantOp::create(
309 rewriter, localOp->getLoc(),
310 getInitializerAttr(localOp.getResult().getType().getElementType()));
311 memref::StoreOp::create(rewriter, localOp->getLoc(),
312 initializer.getResult(), alloca.getResult());
313 return success();
314 }
315};
316
317struct WasmLocalGetConversion : OpConversionPattern<LocalGetOp> {
318 using OpConversionPattern::OpConversionPattern;
319 LogicalResult
320 matchAndRewrite(LocalGetOp localGetOp, LocalGetOp::Adaptor adaptor,
321 ConversionPatternRewriter &rewriter) const override {
322 rewriter.replaceOpWithNewOp<memref::LoadOp>(
323 localGetOp, localGetOp.getResult().getType(), adaptor.getLocalVar(),
324 ValueRange{});
325 return success();
326 }
327};
328
329struct WasmLocalSetConversion : OpConversionPattern<LocalSetOp> {
330 using OpConversionPattern::OpConversionPattern;
331 LogicalResult
332 matchAndRewrite(LocalSetOp localSetOp, LocalSetOp::Adaptor adaptor,
333 ConversionPatternRewriter &rewriter) const override {
334 rewriter.replaceOpWithNewOp<memref::StoreOp>(
335 localSetOp, adaptor.getValue(), adaptor.getLocalVar(), ValueRange{});
336 return success();
337 }
338};
339
340struct WasmLocalTeeConversion : OpConversionPattern<LocalTeeOp> {
341 using OpConversionPattern::OpConversionPattern;
342 LogicalResult
343 matchAndRewrite(LocalTeeOp localTeeOp, LocalTeeOp::Adaptor adaptor,
344 ConversionPatternRewriter &rewriter) const override {
345 memref::StoreOp::create(rewriter, localTeeOp->getLoc(), adaptor.getValue(),
346 adaptor.getLocalVar());
347 rewriter.replaceOp(localTeeOp, adaptor.getValue());
348 return success();
349 }
350};
351
352struct WasmReturnOpConversion : OpConversionPattern<ReturnOp> {
353 using OpConversionPattern::OpConversionPattern;
354
355 LogicalResult
356 matchAndRewrite(ReturnOp returnOp, ReturnOp::Adaptor adaptor,
357 ConversionPatternRewriter &rewriter) const override {
358 rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
359 adaptor.getOperands());
360 return success();
361 }
362};
363
364struct RaiseWasmMLIRPass : public impl::RaiseWasmMLIRBase<RaiseWasmMLIRPass> {
365 void runOnOperation() override {
366 ConversionTarget target{getContext()};
367 target.addIllegalDialect<WasmSSADialect>();
368 target.addLegalDialect<arith::ArithDialect, BuiltinDialect,
369 cf::ControlFlowDialect, func::FuncDialect,
370 memref::MemRefDialect, math::MathDialect>();
371 RewritePatternSet patterns(&getContext());
372 TypeConverter tc{};
373 tc.addConversion([](Type type) -> std::optional<Type> { return type; });
374 tc.addConversion([](LocalRefType type) -> std::optional<Type> {
375 return MemRefType::get({}, type.getElementType());
376 });
377 tc.addTargetMaterialization([](OpBuilder &builder, MemRefType destType,
378 ValueRange values, Location loc) -> Value {
379 if (values.size() != 1 ||
380 values.front().getType() != destType.getElementType())
381 return {};
382 auto localVar = memref::AllocaOp::create(builder, loc, destType);
383 memref::StoreOp::create(builder, loc, values.front(),
384 localVar.getResult());
385 return localVar.getResult();
386 });
388
389 llvm::DenseMap<StringAttr, StringAttr> idxSymToImportSym{};
390 auto *topOp = getOperation();
391 topOp->walk([&idxSymToImportSym, this](ImportOpInterface importOp) {
392 auto const qualifiedImportName = importOp.getQualifiedImportName();
393 auto qualNameAttr = StringAttr::get(&getContext(), qualifiedImportName);
394 idxSymToImportSym.insert(
395 std::make_pair(importOp.getSymbolName(), qualNameAttr));
396 });
397
398 if (failed(applyFullConversion(topOp, target, std::move(patterns))))
399 return signalPassFailure();
400
401 auto symTable = SymbolTable{topOp};
402 for (auto &[oldName, newName] : idxSymToImportSym) {
403 if (failed(symTable.rename(oldName, newName)))
404 return signalPassFailure();
405 }
406 }
407};
408} // namespace
409
411 TypeConverter &tc, RewritePatternSet &patternSet) {
412 auto *ctx = patternSet.getContext();
413 // Disable clang-format in patternSet for readability + small diffs.
414 // clang-format off
415 patternSet
416 .add<
417 WasmAbsOpConversion,
418 WasmAddOpConversion,
419 WasmAndOpConversion,
420 WasmCallOpConversion,
421 WasmCeilOpConversion,
422 WasmClzOpConversion,
423 WasmConstOpConversion,
424 WasmConvertSOpConversion,
425 WasmConvertUOpConversion,
426 WasmCopySignOpConversion,
427 WasmCtzOpConversion,
428 WasmDemoteOpConversion,
429 WasmDivFPOpConversion,
430 WasmDivSIOpConversion,
431 WasmDivUIOpConversion,
432 WasmExtendSOpConversion,
433 WasmExtendUOpConversion,
434 WasmFloorOpConversion,
435 WasmFuncImportOpConversion,
436 WasmFuncOpConversion,
437 WasmGlobalImportOpConverter,
438 WasmGlobalWithConstInitConversion,
439 WasmGlobalWithGetGlobalInitConversion,
440 WasmLocalConversion,
441 WasmLocalGetConversion,
442 WasmLocalSetConversion,
443 WasmLocalTeeConversion,
444 WasmMaxOpConversion,
445 WasmMinOpConversion,
446 WasmMulOpConversion,
447 WasmNegOpConversion,
448 WasmOrOpConversion,
449 WasmPopCntOpConversion,
450 WasmPromoteOpConversion,
451 WasmReinterpretOpConversion,
452 WasmRemSIOpConversion,
453 WasmRemUIOpConversion,
454 WasmReturnOpConversion,
455 WasmShLOpConversion,
456 WasmShRSOpConversion,
457 WasmShRUOpConversion,
458 WasmSqrtOpConversion,
459 WasmSubOpConversion,
460 WasmTruncOpConversion,
461 WasmWrapOpConversion,
462 WasmXOrOpConversion
463 >(tc, ctx);
464 // clang-format on
465}
466
467std::unique_ptr<Pass> createRaiseWasmMLIRPass() {
468 return std::make_unique<RaiseWasmMLIRPass>();
469}
return success()
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
std::unique_ptr< Pass > createRaiseWasmMLIRPass()
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition Types.cpp:47
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
type_range getType() const
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateRaiseWasmMLIRConversionPatterns(TypeConverter &, RewritePatternSet &)
Collect a set of patterns to convert from the Wasm dialect to standard dialects.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307