//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have a rank and stride of 1,
/// with static offset and size. The number of bits in the offset must evenly
/// divide the bitwidth of the new converted type.
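/// For illustration (hypothetical values, not taken from the original file):
/// when converting i4 views into an i8 buffer, the static offset is divided by
/// two and the static size is rounded up, e.g. offset 4, size 7, stride 1
/// becomes offset 2, size 4, stride 1.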
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      typename MemRefOpTy::Adaptor adaptor,
                                      MemRefOpTy op, MemRefType newTy) {
  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");

  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(sizes,
                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "offset not multiple of elementsPerByte is not "
                      "supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(ceilDiv(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
                                          *adaptor.getODSOperands(0).begin(),
                                          offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
/// value is treated as an array of elements of width `sourceBits`.
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because there are two elements in one i8 and each
/// element has 4 bits.
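/// As a concrete sketch (illustrative numbers only): for sourceBits = 4,
/// targetBits = 8 and srcIdx = 5, the returned bit offset is (5 % 2) * 4 = 4,
/// i.e. the element occupies the high nibble of its containing i8.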
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a sub-byte value, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a sub-byte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
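/// As an illustrative walk-through (values not from the original source): for
/// srcBits = 4, dstBits = 32 and a bit offset of 4, the right-aligned mask
/// 0x0000000F is shifted left to 0x000000F0 and XOR'ed with -1 (0xFFFFFFFF),
/// yielding the clearing mask 0xFFFFFF0F.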
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
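/// For illustration (hypothetical values): with srcBits = 4 and dstBits = 8,
/// a linearized i4 index of 7 maps to i8 index 7 floordiv 2 = 3.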
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

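/// Linearizes `indices` of `memref` into a single index expressed in units of
/// `srcBits`, using the offset, sizes and strides recovered from a
/// memref.extract_strided_metadata op.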
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

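/// Converts memref.alloc / memref.alloca over a narrow element type into an
/// allocation of the wider load/store type with a linearized shape. As an
/// illustrative example (sizes assumed, not taken from a test): with an i8
/// load/store width, memref.alloc() : memref<3x5xi4> becomes
/// memref.alloc() : memref<8xi8>, since ceildiv(3 * 5, 2) = 8.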
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

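/// Emulates a load of a narrow element by loading the containing wide element
/// and extracting the relevant bits. As an illustrative sketch (IR assumed,
/// not copied from a test), with i4 emulated on i8:
///   %v = memref.load %src[%c5] : memref<8xi4>
/// becomes roughly a memref.load of element 5 floordiv 2 = 2 of the converted
/// memref<4xi8>, an arith.shrsi by the bit offset (5 mod 2) * 4 = 4, and a
/// truncation or masking back to the i4 value.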
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note: currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

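/// Emulates a store of a narrow element with a pair of atomic read-modify-write
/// ops on the containing wide element. As an illustrative sketch (IR assumed,
/// not copied from a test), with i4 emulated on i8:
///   memref.store %v, %dst[%c5] : memref<8xi4>
/// becomes roughly an arith.extui of %v to i8 shifted left by the bit offset
/// (4), then a memref.atomic_rmw andi with the clearing mask 0x0F followed by
/// a memref.atomic_rmw ori with the shifted value, both on element 2 of the
/// converted memref<4xi8>.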
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern would
/// never run. This pattern is mostly used for testing purposes.
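/// For illustration (hypothetical IR, not from a test): with i4 emulated on
/// i8, memref.subview %src[4] [6] [1] : memref<16xi4> to memref<6xi4, ...>
/// is rewritten to a subview with offset 2 and size 3 on the converted
/// memref<8xi8>.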
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support offsets for 1-D subviews.
    if (op.getType().getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank != 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

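/// Returns the flattened 1-D shape of `ty` after scaling from `srcBits` to
/// `dstBits` elements: an empty shape for rank-0 memrefs, and a single dynamic
/// dimension if any source dimension is dynamic. For illustration
/// (hypothetical values): memref<3x5xi4> with an i8 load/store width
/// linearizes to {8}, i.e. ceildiv(15, 2).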
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale)
  // to accommodate all the values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only an innermost stride of 1 is handled; check it here.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check if the offset in bits is a multiple of loadStoreWidth and,
            // if so, divide it by loadStoreWidth to get the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}
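
// Usage sketch (illustrative only; the surrounding pass, target, and variable
// names are assumptions, not part of this file): a conversion pass would
// typically wire the two entry points above together roughly as follows.
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
//
//   ConversionTarget target(*ctx);
//   target.addDynamicallyLegalOp<memref::AllocOp, memref::LoadOp,
//                                memref::StoreOp>(
//       [&](Operation *op) { return typeConverter.isLegal(op); });
//
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
//   (void)applyPartialConversion(getOperation(), target, std::move(patterns));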