MLIR  19.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/MathExtras.h"
19 #include <cassert>
20 
21 namespace mlir::memref {
22 #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
23 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
24 } // namespace mlir::memref
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 //===----------------------------------------------------------------------===//
31 // ConvertMemRefAlloc
32 //===----------------------------------------------------------------------===//
33 
34 struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
36 
38  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
39  ConversionPatternRewriter &rewriter) const override {
40  Type newTy = getTypeConverter()->convertType(op.getType());
41  if (!newTy)
42  return rewriter.notifyMatchFailure(
43  op->getLoc(),
44  llvm::formatv("failed to convert memref type: {0}", op.getType()));
45 
46  rewriter.replaceOpWithNewOp<memref::AllocOp>(
47  op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
48  adaptor.getAlignmentAttr());
49  return success();
50  }
51 };
52 
53 //===----------------------------------------------------------------------===//
54 // ConvertMemRefLoad
55 //===----------------------------------------------------------------------===//
56 
57 struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
59 
61  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
62  ConversionPatternRewriter &rewriter) const override {
63  Type newResTy = getTypeConverter()->convertType(op.getType());
64  if (!newResTy)
65  return rewriter.notifyMatchFailure(
66  op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
67  op.getMemRefType()));
68 
69  rewriter.replaceOpWithNewOp<memref::LoadOp>(
70  op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
71  op.getNontemporal());
72  return success();
73  }
74 };
75 
76 //===----------------------------------------------------------------------===//
77 // ConvertMemRefStore
78 //===----------------------------------------------------------------------===//
79 
80 struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
82 
84  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  Type newTy = getTypeConverter()->convertType(op.getMemRefType());
87  if (!newTy)
88  return rewriter.notifyMatchFailure(
89  op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
90  op.getMemRefType()));
91 
92  rewriter.replaceOpWithNewOp<memref::StoreOp>(
93  op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
94  op.getNontemporal());
95  return success();
96  }
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // Pass Definition
101 //===----------------------------------------------------------------------===//
102 
103 struct EmulateWideIntPass final
104  : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
105  using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
106 
107  void runOnOperation() override {
108  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
109  signalPassFailure();
110  return;
111  }
112 
113  Operation *op = getOperation();
114  MLIRContext *ctx = op->getContext();
115 
116  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
118  ConversionTarget target(*ctx);
119  target.addDynamicallyLegalDialect<
120  arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
121  [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
122 
123  RewritePatternSet patterns(ctx);
124  // Add common pattenrs to support contants, functions, etc.
125  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
126 
127  memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
128 
129  if (failed(applyPartialConversion(op, target, std::move(patterns))))
130  signalPassFailure();
131  }
132 };
133 
134 } // end anonymous namespace
135 
136 //===----------------------------------------------------------------------===//
137 // Public Interface Definition
138 //===----------------------------------------------------------------------===//
139 
141  arith::WideIntEmulationConverter &typeConverter,
142  RewritePatternSet &patterns) {
143  // Populate `memref.*` conversion patterns.
144  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
145  typeConverter, patterns.getContext());
146 }
147 
149  arith::WideIntEmulationConverter &typeConverter) {
150  typeConverter.addConversion(
151  [&typeConverter](MemRefType ty) -> std::optional<Type> {
152  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
153  if (!intTy)
154  return ty;
155 
156  if (intTy.getIntOrFloatBitWidth() <=
157  typeConverter.getMaxTargetIntBitWidth())
158  return ty;
159 
160  Type newElemTy = typeConverter.convertType(intTy);
161  if (!newElemTy)
162  return std::nullopt;
163 
164  return ty.cloneWith(std::nullopt, newElemTy);
165  });
166 }
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Converts integer types that are too wide for the target by splitting them in two halves, thus turning them into supported ones.
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
void populateMemRefWideIntEmulationPatterns(arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating wide integer memref operations with ops over narrower integer types.
void populateMemRefWideIntEmulationConversions(arith::WideIntEmulationConverter &typeConverter)
Appends type conversions for emulating wide integer memref operations with ops over narrower integer types.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26