MLIR 23.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/MathExtras.h"
19#include <cassert>
20
21namespace mlir::memref {
22#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
23#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
24} // namespace mlir::memref
25
26using namespace mlir;
27
28namespace {
29
30//===----------------------------------------------------------------------===//
31// ConvertMemRefAlloc
32//===----------------------------------------------------------------------===//
33
34struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
35 using OpConversionPattern::OpConversionPattern;
36
37 LogicalResult
38 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
39 ConversionPatternRewriter &rewriter) const override {
40 Type newTy = getTypeConverter()->convertType(op.getType());
41 if (!newTy)
42 return rewriter.notifyMatchFailure(
43 op->getLoc(),
44 llvm::formatv("failed to convert memref type: {0}", op.getType()));
45
46 rewriter.replaceOpWithNewOp<memref::AllocOp>(
47 op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
48 adaptor.getAlignmentAttr());
49 return success();
50 }
51};
52
53//===----------------------------------------------------------------------===//
54// ConvertMemRefLoad
55//===----------------------------------------------------------------------===//
56
57struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
58 using OpConversionPattern::OpConversionPattern;
59
60 LogicalResult
61 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter) const override {
63 Type newResTy = getTypeConverter()->convertType(op.getType());
64 if (!newResTy)
65 return rewriter.notifyMatchFailure(
66 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
67 op.getMemRefType()));
68
69 rewriter.replaceOpWithNewOp<memref::LoadOp>(
70 op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
71 op.getNontemporal());
72 return success();
73 }
74};
75
76//===----------------------------------------------------------------------===//
77// ConvertMemRefStore
78//===----------------------------------------------------------------------===//
79
80struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
81 using OpConversionPattern::OpConversionPattern;
82
83 LogicalResult
84 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 Type newTy = getTypeConverter()->convertType(op.getMemRefType());
87 if (!newTy)
88 return rewriter.notifyMatchFailure(
89 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
90 op.getMemRefType()));
91
92 rewriter.replaceOpWithNewOp<memref::StoreOp>(
93 op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
94 op.getNontemporal());
95 return success();
96 }
97};
98
99//===----------------------------------------------------------------------===//
100// Pass Definition
101//===----------------------------------------------------------------------===//
102
103struct EmulateWideIntPass final
104 : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
105 using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
106
107 void runOnOperation() override {
108 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
109 signalPassFailure();
110 return;
111 }
112
113 Operation *op = getOperation();
114 MLIRContext *ctx = op->getContext();
115
116 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
118 ConversionTarget target(*ctx);
119 target.addDynamicallyLegalDialect<
120 arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
121 [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
122
123 RewritePatternSet patterns(ctx);
124 // Add common patterns to support contants, functions, etc.
125 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
126
127 memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
128
129 if (failed(applyPartialConversion(op, target, std::move(patterns))))
130 signalPassFailure();
131 }
132};
133
134} // end anonymous namespace
135
136//===----------------------------------------------------------------------===//
137// Public Interface Definition
138//===----------------------------------------------------------------------===//
139
141 const arith::WideIntEmulationConverter &typeConverter,
142 RewritePatternSet &patterns) {
143 // Populate `memref.*` conversion patterns.
144 patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
145 typeConverter, patterns.getContext());
146}
147
149 arith::WideIntEmulationConverter &typeConverter) {
150 typeConverter.addConversion(
151 [&typeConverter](MemRefType ty) -> std::optional<Type> {
152 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
153 if (!intTy)
154 return ty;
155
156 if (intTy.getIntOrFloatBitWidth() <=
157 typeConverter.getMaxTargetIntBitWidth())
158 return ty;
159
160 Type newElemTy = typeConverter.convertType(intTy);
161 if (!newElemTy)
162 return nullptr;
163
164 return ty.cloneWith(std::nullopt, newElemTy);
165 });
166}
return success()
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
void populateMemRefWideIntEmulationPatterns(const arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating wide integer memref operations with ops over narrower integer types.
void populateMemRefWideIntEmulationConversions(arith::WideIntEmulationConverter &typeConverter)
Appends type conversions for emulating wide integer memref operations with ops over narrower integer types.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.