//===- EmulateNarrowType.cpp - Narrow type emulation ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.
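///
/// A worked instance of the index arithmetic below (illustrative values, not
/// taken from a test): casting a memref of i4 to the emulated i8 type gives
/// elementsPerByte = 8 / 4 = 2, so a static offset of 4 and a size of 7
/// become offset 4 / 2 = 2 and size ceilDiv(7, 2) = 4 in i8 units.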
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(
          sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because there are two elements in one i8, each
/// 4 bits wide.
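///
/// As a concrete instance of that formula (illustrative numbers): with
/// sourceBits = 4, targetBits = 8, and srcIdx = 5, the offset is
/// (5 % 2) * 4 = 4, i.e. the value occupies the high nibble of the third byte.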
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
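///
/// Step by step for that example (a sketch of the arithmetic, not lifted from
/// a test): maskRightAligned = (1 << 4) - 1 = 0xF, shifting it left by the
/// bit offset 4 gives 0xF0, and XOR with -1 (all ones) flips it into
/// 0xFFFFFF0F.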
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
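///
/// For instance (illustrative numbers): with srcBits = 4 and dstBits = 8, a
/// linearized i4 index of 5 maps to byte index 5 floordiv 2 = 2.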
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

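/// Linearizes the given `indices` of `memref` into a single index expressed in
/// `srcBits` granularity, using the offset, sizes, and strides recovered from
/// the memref via memref.extract_strided_metadata.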
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

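/// Converts memref.alloc / memref.alloca over a narrow element type into an
/// allocation of the linearized, wider type. A sketch of the intended rewrite
/// (illustrative types, assuming i4 emulated on i8):
///
///   %a = memref.alloc() : memref<3x4xi4>
///   // becomes, roughly,
///   %a = memref.alloc() : memref<6xi8>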
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

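/// Emulates a narrow-type memref.load by loading the enclosing wider element,
/// shifting the requested bits down to the least significant position, and
/// then either masking with arith.andi or truncating with arith.trunci,
/// depending on the converted result type. A rough sketch, assuming i4
/// emulated on i8 (illustrative SSA names, not actual pass output):
///
///   %v = memref.load %src[%i] : memref<8xi4>
///   // becomes approximately
///   %lin  = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%i]
///   %byte = memref.load %dst[%lin] : memref<4xi8>
///   %sh   = arith.shrsi %byte, %off : i8
///   %v8   = arith.andi %sh, %mask : i8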
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // differs from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

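/// Emulates a narrow-type memref.store by zero-extending the value to the
/// wide element type and updating only the relevant bits of the destination
/// element with two read-modify-write ops: an atomic `andi` with the write
/// mask to clear the target bits, then an atomic `ori` with the shifted value
/// to set them. A rough sketch, assuming i4 emulated on i8 (illustrative SSA
/// names, not actual pass output):
///
///   memref.store %v, %src[%i] : memref<8xi4>
///   // becomes approximately
///   %ext = arith.extui %v : i4 to i8
///   %set = arith.shli %ext, %off : i8
///   memref.atomic_rmw andi %mask, %dst[%lin] : (i8, memref<4xi8>) -> i8
///   memref.atomic_rmw ori %set, %dst[%lin] : (i8, memref<4xi8>) -> i8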
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
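///
/// A worked 1-D instance (illustrative only): a subview of a memref<8xi4>
/// with offset 4, size 4, and stride 1 is rewritten, after i4 -> i8
/// emulation, into a subview of the flattened memref<4xi8> with offset 2,
/// size 2, and stride 1.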
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy = dyn_cast<MemRefType>(
        getTypeConverter()->convertType(subViewOp.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::any_of(
            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating `memref.collapse_shape` becomes a no-op after emulation: memrefs
/// are flattened to a single dimension as part of the emulation, so there is
/// no dimension left to collapse.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>,
               ConvertMemRefCollapseShape, ConvertMemRefLoad,
               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

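/// Returns the linearized, scaled shape for the emulated memref type. A worked
/// instance of the arithmetic below (illustrative values): a 3x4 memref of i4
/// emulated on i8 holds 12 narrow elements packed two per i8, so the
/// linearized shape is {ceilDiv(12, 2)} = {6}.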
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

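/// Registers the MemRefType conversion with the narrow type emulation
/// converter. As an illustrative instance of the rule below (assuming a
/// load/store bitwidth of 8): memref<3x4xi4> converts to memref<6xi8>, while
/// memref<3x4xi8> and non-integer element types are left unchanged.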
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle an innermost stride of 1; check it here.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check if the offset in bits is a multiple of the loadStoreWidth
            // and, if so, divide it by the loadStoreWidth to get the new
            // offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}