MLIR  19.0.0git
EmulateNarrowType.cpp
Go to the documentation of this file.
1 //===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
36 /// type. The result MemRefType of the old op must have a rank and stride of 1,
37 /// with static offset and size. The number of bits in the offset must evenly
38 /// divide the bitwidth of the new converted type.
39 template <typename MemRefOpTy>
41  typename MemRefOpTy::Adaptor adaptor,
42  MemRefOpTy op, MemRefType newTy) {
43  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
44  std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
45  "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
46 
47  auto convertedElementType = newTy.getElementType();
48  auto oldElementType = op.getType().getElementType();
49  int srcBits = oldElementType.getIntOrFloatBitWidth();
50  int dstBits = convertedElementType.getIntOrFloatBitWidth();
51  if (dstBits % srcBits != 0) {
52  return rewriter.notifyMatchFailure(op,
53  "only dstBits % srcBits == 0 supported");
54  }
55 
56  // Only support stride of 1.
57  if (llvm::any_of(op.getStaticStrides(),
58  [](int64_t stride) { return stride != 1; })) {
59  return rewriter.notifyMatchFailure(op->getLoc(),
60  "stride != 1 is not supported");
61  }
62 
63  auto sizes = op.getStaticSizes();
64  int64_t offset = op.getStaticOffset(0);
65  // Only support static sizes and offsets.
66  if (llvm::any_of(sizes,
67  [](int64_t size) { return size == ShapedType::kDynamic; }) ||
68  offset == ShapedType::kDynamic) {
69  return rewriter.notifyMatchFailure(
70  op->getLoc(), "dynamic size or offset is not supported");
71  }
72 
73  int elementsPerByte = dstBits / srcBits;
74  if (offset % elementsPerByte != 0) {
75  return rewriter.notifyMatchFailure(
76  op->getLoc(), "offset not multiple of elementsPerByte is not "
77  "supported");
78  }
79 
81  if (sizes.size())
82  size.push_back(ceilDiv(sizes[0], elementsPerByte));
83  offset = offset / elementsPerByte;
84 
85  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
86  *adaptor.getODSOperands(0).begin(),
87  offset, size, op.getStaticStrides());
88  return success();
89 }
90 
91 /// When data is loaded/stored in `targetBits` granularity, but is used in
92 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
93 /// treated as an array of elements of width `sourceBits`.
94 /// Return the bit offset of the value at position `srcIdx`. For example, if
95 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
96 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
97 /// element has 4 bits.
99  int sourceBits, int targetBits,
100  OpBuilder &builder) {
101  assert(targetBits % sourceBits == 0);
102  AffineExpr s0;
103  bindSymbols(builder.getContext(), s0);
104  int scaleFactor = targetBits / sourceBits;
105  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
106  OpFoldResult offsetVal =
107  affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
108  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
109  IntegerType dstType = builder.getIntegerType(targetBits);
110  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
111 }
112 
113 /// When writing a subbyte size, masked bitwise operations are used to only
114 /// modify the relevant bits. This function returns an and mask for clearing
115 /// the destination bits in a subbyte write. E.g., when writing to the second
116 /// i4 in an i32, 0xFFFFFF0F is created.
117 static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
118  int64_t srcBits, int64_t dstBits,
119  Value bitwidthOffset, OpBuilder &builder) {
120  auto dstIntegerType = builder.getIntegerType(dstBits);
121  auto maskRightAlignedAttr =
122  builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
123  Value maskRightAligned = builder.create<arith::ConstantOp>(
124  loc, dstIntegerType, maskRightAlignedAttr);
125  Value writeMaskInverse =
126  builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
127  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
128  Value flipVal =
129  builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
130  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
131 }
132 
133 /// Returns the scaled linearized index based on the `srcBits` and `dstBits`
134 /// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
135 /// the returned index has the granularity of `dstBits`
137  OpFoldResult linearizedIndex,
138  int64_t srcBits, int64_t dstBits) {
139  AffineExpr s0;
140  bindSymbols(builder.getContext(), s0);
141  int64_t scaler = dstBits / srcBits;
142  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
143  builder, loc, s0.floorDiv(scaler), {linearizedIndex});
144  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
145 }
146 
147 static OpFoldResult
148 getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
149  const SmallVector<OpFoldResult> &indices,
150  Value memref) {
151  auto stridedMetadata =
152  builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
153  OpFoldResult linearizedIndices;
154  std::tie(std::ignore, linearizedIndices) =
156  builder, loc, srcBits, srcBits,
157  stridedMetadata.getConstifiedMixedOffset(),
158  stridedMetadata.getConstifiedMixedSizes(),
159  stridedMetadata.getConstifiedMixedStrides(), indices);
160  return linearizedIndices;
161 }
162 
163 namespace {
164 
165 //===----------------------------------------------------------------------===//
166 // ConvertMemRefAllocation
167 //===----------------------------------------------------------------------===//
168 
169 template <typename OpTy>
170 struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
172 
174  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  static_assert(std::is_same<OpTy, memref::AllocOp>() ||
177  std::is_same<OpTy, memref::AllocaOp>(),
178  "expected only memref::AllocOp or memref::AllocaOp");
179  auto currentType = cast<MemRefType>(op.getMemref().getType());
180  auto newResultType = dyn_cast<MemRefType>(
181  this->getTypeConverter()->convertType(op.getType()));
182  if (!newResultType) {
183  return rewriter.notifyMatchFailure(
184  op->getLoc(),
185  llvm::formatv("failed to convert memref type: {0}", op.getType()));
186  }
187 
188  // Special case zero-rank memrefs.
189  if (currentType.getRank() == 0) {
190  rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
191  adaptor.getSymbolOperands(),
192  adaptor.getAlignmentAttr());
193  return success();
194  }
195 
196  Location loc = op.getLoc();
197  OpFoldResult zero = rewriter.getIndexAttr(0);
198  SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
199 
200  // Get linearized type.
201  int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
202  int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
203  SmallVector<OpFoldResult> sizes = op.getMixedSizes();
204 
205  memref::LinearizedMemRefInfo linearizedMemRefInfo =
207  rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
208  SmallVector<Value> dynamicLinearizedSize;
209  if (!newResultType.hasStaticShape()) {
210  dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
211  rewriter, loc, linearizedMemRefInfo.linearizedSize));
212  }
213 
214  rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
215  adaptor.getSymbolOperands(),
216  adaptor.getAlignmentAttr());
217  return success();
218  }
219 };
220 
221 //===----------------------------------------------------------------------===//
222 // ConvertMemRefAssumeAlignment
223 //===----------------------------------------------------------------------===//
224 
225 struct ConvertMemRefAssumeAlignment final
226  : OpConversionPattern<memref::AssumeAlignmentOp> {
228 
230  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
231  ConversionPatternRewriter &rewriter) const override {
232  Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
233  if (!newTy) {
234  return rewriter.notifyMatchFailure(
235  op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
236  op.getMemref().getType()));
237  }
238 
239  rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
240  op, adaptor.getMemref(), adaptor.getAlignmentAttr());
241  return success();
242  }
243 };
244 
245 //===----------------------------------------------------------------------===//
246 // ConvertMemRefLoad
247 //===----------------------------------------------------------------------===//
248 
249 struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
251 
253  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
254  ConversionPatternRewriter &rewriter) const override {
255  auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
256  auto convertedElementType = convertedType.getElementType();
257  auto oldElementType = op.getMemRefType().getElementType();
258  int srcBits = oldElementType.getIntOrFloatBitWidth();
259  int dstBits = convertedElementType.getIntOrFloatBitWidth();
260  if (dstBits % srcBits != 0) {
261  return rewriter.notifyMatchFailure(
262  op, "only dstBits % srcBits == 0 supported");
263  }
264 
265  Location loc = op.getLoc();
266  // Special case 0-rank memref loads.
267  Value bitsLoad;
268  if (convertedType.getRank() == 0) {
269  bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
270  ValueRange{});
271  } else {
272  // Linearize the indices of the original load instruction. Do not account
273  // for the scaling yet. This will be accounted for later.
274  OpFoldResult linearizedIndices = getLinearizedSrcIndices(
275  rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
276 
277  Value newLoad = rewriter.create<memref::LoadOp>(
278  loc, adaptor.getMemref(),
279  getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
280  dstBits));
281 
282  // Get the offset and shift the bits to the rightmost.
283  // Note, currently only the big-endian is supported.
284  Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
285  srcBits, dstBits, rewriter);
286  bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
287  }
288 
289  // Get the corresponding bits. If the arith computation bitwidth equals
290  // to the emulated bitwidth, we apply a mask to extract the low bits.
291  // It is not clear if this case actually happens in practice, but we keep
292  // the operations just in case. Otherwise, if the arith computation bitwidth
293  // is different from the emulated bitwidth we truncate the result.
294  Operation *result;
295  auto resultTy = getTypeConverter()->convertType(oldElementType);
296  if (resultTy == convertedElementType) {
297  auto mask = rewriter.create<arith::ConstantOp>(
298  loc, convertedElementType,
299  rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
300 
301  result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
302  } else {
303  result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
304  }
305 
306  rewriter.replaceOp(op, result->getResult(0));
307  return success();
308  }
309 };
310 
311 //===----------------------------------------------------------------------===//
312 // ConvertMemRefReinterpretCast
313 //===----------------------------------------------------------------------===//
314 
315 /// Output types should be at most one dimensional, so only the 0 or 1
316 /// dimensional cases are supported.
317 struct ConvertMemRefReinterpretCast final
318  : OpConversionPattern<memref::ReinterpretCastOp> {
320 
322  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
323  ConversionPatternRewriter &rewriter) const override {
324  MemRefType newTy =
325  dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
326  if (!newTy) {
327  return rewriter.notifyMatchFailure(
328  op->getLoc(),
329  llvm::formatv("failed to convert memref type: {0}", op.getType()));
330  }
331 
332  // Only support for 0 or 1 dimensional cases.
333  if (op.getType().getRank() > 1) {
334  return rewriter.notifyMatchFailure(
335  op->getLoc(), "subview with rank > 1 is not supported");
336  }
337 
338  return convertCastingOp(rewriter, adaptor, op, newTy);
339  }
340 };
341 
342 //===----------------------------------------------------------------------===//
343 // ConvertMemrefStore
344 //===----------------------------------------------------------------------===//
345 
346 struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
348 
350  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
351  ConversionPatternRewriter &rewriter) const override {
352  auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
353  int srcBits = op.getMemRefType().getElementTypeBitWidth();
354  int dstBits = convertedType.getElementTypeBitWidth();
355  auto dstIntegerType = rewriter.getIntegerType(dstBits);
356  if (dstBits % srcBits != 0) {
357  return rewriter.notifyMatchFailure(
358  op, "only dstBits % srcBits == 0 supported");
359  }
360 
361  Location loc = op.getLoc();
362  Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
363  adaptor.getValue());
364 
365  // Special case 0-rank memref stores. No need for masking.
366  if (convertedType.getRank() == 0) {
367  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
368  extendedInput, adaptor.getMemref(),
369  ValueRange{});
370  rewriter.eraseOp(op);
371  return success();
372  }
373 
374  OpFoldResult linearizedIndices = getLinearizedSrcIndices(
375  rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
376  Value storeIndices = getIndicesForLoadOrStore(
377  rewriter, loc, linearizedIndices, srcBits, dstBits);
378  Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
379  dstBits, rewriter);
380  Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
381  dstBits, bitwidthOffset, rewriter);
382  // Align the value to write with the destination bits
383  Value alignedVal =
384  rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
385 
386  // Clear destination bits
387  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
388  writeMask, adaptor.getMemref(),
389  storeIndices);
390  // Write srcs bits to destination
391  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
392  alignedVal, adaptor.getMemref(),
393  storeIndices);
394  rewriter.eraseOp(op);
395  return success();
396  }
397 };
398 
399 //===----------------------------------------------------------------------===//
400 // ConvertMemRefSubview
401 //===----------------------------------------------------------------------===//
402 
403 /// Emulating narrow ints on subview have limited support, supporting only
404 /// static offset and size and stride of 1. Ideally, the subview should be
405 /// folded away before running narrow type emulation, and this pattern would
406 /// never run. This pattern is mostly used for testing pruposes.
407 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
409 
411  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
412  ConversionPatternRewriter &rewriter) const override {
413  MemRefType newTy =
414  dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
415  if (!newTy) {
416  return rewriter.notifyMatchFailure(
417  op->getLoc(),
418  llvm::formatv("failed to convert memref type: {0}", op.getType()));
419  }
420 
421  // Only support offset for 1-D subview.
422  if (op.getType().getRank() != 1) {
423  return rewriter.notifyMatchFailure(
424  op->getLoc(), "subview with rank > 1 is not supported");
425  }
426 
427  return convertCastingOp(rewriter, adaptor, op, newTy);
428  }
429 };
430 
431 //===----------------------------------------------------------------------===//
432 // ConvertMemRefCollapseShape
433 //===----------------------------------------------------------------------===//
434 
435 /// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
436 /// that we flatten memrefs to a single dimension as part of the emulation and
437 /// there is no dimension to collapse any further.
438 struct ConvertMemRefCollapseShape final
439  : OpConversionPattern<memref::CollapseShapeOp> {
441 
443  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Value srcVal = adaptor.getSrc();
446  auto newTy = dyn_cast<MemRefType>(srcVal.getType());
447  if (!newTy)
448  return failure();
449 
450  if (newTy.getRank() != 1)
451  return failure();
452 
453  rewriter.replaceOp(collapseShapeOp, srcVal);
454  return success();
455  }
456 };
457 
458 } // end anonymous namespace
459 
460 //===----------------------------------------------------------------------===//
461 // Public Interface Definition
462 //===----------------------------------------------------------------------===//
463 
466  RewritePatternSet &patterns) {
467 
468  // Populate `memref.*` conversion patterns.
469  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
470  ConvertMemRefAllocation<memref::AllocaOp>,
471  ConvertMemRefCollapseShape, ConvertMemRefLoad,
472  ConvertMemrefStore, ConvertMemRefAssumeAlignment,
473  ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
474  typeConverter, patterns.getContext());
476 }
477 
478 static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
479  int dstBits) {
480  if (ty.getRank() == 0)
481  return {};
482 
483  int64_t linearizedShape = 1;
484  for (auto shape : ty.getShape()) {
485  if (shape == ShapedType::kDynamic)
486  return {ShapedType::kDynamic};
487  linearizedShape *= shape;
488  }
489  int scale = dstBits / srcBits;
490  // Scale the size to the ceilDiv(linearizedShape, scale)
491  // to accomodate all the values.
492  linearizedShape = (linearizedShape + scale - 1) / scale;
493  return {linearizedShape};
494 }
495 
497  arith::NarrowTypeEmulationConverter &typeConverter) {
498  typeConverter.addConversion(
499  [&typeConverter](MemRefType ty) -> std::optional<Type> {
500  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
501  if (!intTy)
502  return ty;
503 
504  unsigned width = intTy.getWidth();
505  unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
506  if (width >= loadStoreWidth)
507  return ty;
508 
509  // Currently only handle innermost stride being 1, checking
510  SmallVector<int64_t> strides;
511  int64_t offset;
512  if (failed(getStridesAndOffset(ty, strides, offset)))
513  return std::nullopt;
514  if (!strides.empty() && strides.back() != 1)
515  return std::nullopt;
516 
517  auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
518  intTy.getSignedness());
519  if (!newElemTy)
520  return std::nullopt;
521 
522  StridedLayoutAttr layoutAttr;
523  // If the offset is 0, we do not need a strided layout as the stride is
524  // 1, so we only use the strided layout if the offset is not 0.
525  if (offset != 0) {
526  if (offset == ShapedType::kDynamic) {
527  layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
528  ArrayRef<int64_t>{1});
529  } else {
530  // Check if the number of bytes are a multiple of the loadStoreWidth
531  // and if so, divide it by the loadStoreWidth to get the offset.
532  if ((offset * width) % loadStoreWidth != 0)
533  return std::nullopt;
534  offset = (offset * width) / loadStoreWidth;
535 
536  layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
537  ArrayRef<int64_t>{1});
538  }
539  }
540 
541  return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
542  newElemTy, layoutAttr, ty.getMemorySpace());
543  });
544 }
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
When data is loaded/stored in targetBits granularity, but is used in sourceBits granularity (sourceBits < targetBits), the targetBits value is treated as an array of elements of width sourceBits; returns the bit offset of the element at position srcIdx.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc, OpFoldResult linearizedIndex, int64_t srcBits, int64_t dstBits)
Returns the scaled linearized index based on the srcBits and dstBits sizes.
static SmallVector< int64_t > getLinearizedShape(MemRefType ty, int srcBits, int dstBits)
static OpFoldResult getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector< OpFoldResult > &indices, Value memref)
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter, typename MemRefOpTy::Adaptor adaptor, MemRefOpTy op, MemRefType newTy)
Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted type.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, int64_t srcBits, int64_t dstBits, Value bitwidthOffset, OpBuilder &builder)
When writing a subbyte size, masked bitwise operations are used to only modify the relevant bits.
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:883
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Converts narrow integer or float types that are not supported by the target hardware to wider types.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
void populateMemRefNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:51
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:363
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
For a memref with offset, sizes and strides, returns the offset and size to use for the linearized me...
Definition: MemRefUtils.h:45