1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4, using
11 // wider types, e.g. i8.
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are 8 bits wide or wider.
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
41 using namespace mlir;
42 
43 #define DEBUG_TYPE "vector-narrow-type-emulation"
44 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45 #define DBGSNL() (llvm::dbgs() << "\n")
46 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
47 
48 using VectorValue = TypedValue<VectorType>;
49 using MemRefValue = TypedValue<BaseMemRefType>;
50 
51 //===----------------------------------------------------------------------===//
52 // Utils
53 //===----------------------------------------------------------------------===//
54 
55 /// Returns a compressed mask for the emulated vector. For example, when
56 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
57 /// elements span two dest elements), this method compresses `vector<8xi1>`
58 /// into `vector<2xi1>`.
59 ///
60 /// The compressed/output mask value is set iff any mask in the corresponding
61 /// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
62 /// `numSrcElemsPerDest` is 2 and `numFrontPadElems` is 1, the
63 /// following mask:
64 ///
65 /// %mask = [1, 1, 0, 0, 0, 0]
66 ///
67 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
68 /// will be added in the back to make the number of elements a multiple of
69 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
70 ///
71 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
72 ///
73 /// then it will return the following new compressed mask:
74 ///
75 /// %mask = [1, 1, 0, 0]
76 ///
77 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
78 /// `numSrcElemsPerDest`.
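///
/// As an illustrative sketch (assuming `numSrcElemsPerDest` = 4 and
/// `numFrontPadElems` = 0), the `vector.create_mask` case rewrites
///
///   %mask = vector.create_mask %n : vector<8xi1>
///
/// into roughly
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%n]
///   %new_mask = vector.create_mask %idx : vector<2xi1>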
79 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
80  Location loc, Value mask,
81  int numSrcElems,
82  int numSrcElemsPerDest,
83  int numFrontPadElems = 0) {
84 
85  assert(numFrontPadElems < numSrcElemsPerDest &&
86  "numFrontPadElems must be less than numSrcElemsPerDest");
87 
88  auto numDestElems =
89  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
90  numSrcElemsPerDest;
91 
92  Operation *maskOp = mask.getDefiningOp();
93  SmallVector<vector::ExtractOp> extractOps;
94  // TODO: Add support for `vector.splat`.
95  // Finding the mask creation operation.
96  while (maskOp &&
97  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
98  maskOp)) {
99  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
100  maskOp = extractOp.getVector().getDefiningOp();
101  extractOps.push_back(extractOp);
102  } else break; // Avoid looping forever on unrecognized mask producers.
103  }
104 
105  if (!maskOp ||
106  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp))
107  return failure();
108 
109  // Computing the "compressed" mask. All the emulation logic (i.e. computing
110  // new mask index) only happens on the last dimension of the vectors.
111  SmallVector<int64_t> maskShape(
112  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
113  maskShape.back() = numDestElems;
114  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
115  std::optional<Operation *> newMask =
116  TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
117  .Case<vector::CreateMaskOp>(
118  [&](auto createMaskOp) -> std::optional<Operation *> {
119  OperandRange maskOperands = createMaskOp.getOperands();
120  // The `vector.create_mask` op creates a mask arrangement
121  // without any zeros at the front. Also, because
122  // `numFrontPadElems` is strictly smaller than
123  // `numSrcElemsPerDest`, the compressed mask generated by
124  // padding the original mask by `numFrontPadElems` will not
125  // have any zeros at the front either.
126  AffineExpr s0;
127  bindSymbols(rewriter.getContext(), s0);
128  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
129  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
130  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
131  rewriter, loc, s0, origIndex);
132  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
133  newMaskOperands.push_back(
134  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
135  return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
136  newMaskOperands);
137  })
138  .Case<vector::ConstantMaskOp>(
139  [&](auto constantMaskOp) -> std::optional<Operation *> {
140  // Take the shape of mask, compress its trailing dimension:
141  SmallVector<int64_t> maskDimSizes(
142  constantMaskOp.getMaskDimSizes());
143  int64_t &maskIndex = maskDimSizes.back();
144  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
145  numSrcElemsPerDest);
146  return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
147  maskDimSizes);
148  })
149  .Case<arith::ConstantOp>([&](auto constantOp)
150  -> std::optional<Operation *> {
151  // TODO: Support multiple dimensions.
152  if (maskShape.size() != 1)
153  return std::nullopt;
154  // Rearrange the original mask values to cover the whole potential
155  // loading region. For example, in the case of using byte-size for
156  // emulation, given the following mask:
157  //
158  // %mask = [0, 1, 0, 1, 0, 0]
159  //
160  // With a front offset of 1, the mask will be padded with 0s in the front
161  // and back so that:
162  // 1. It is aligned with the effective loading bits.
163  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
164  // coverage size is a multiple of bytes). The new mask will look like
165  // this before compressing:
166  //
167  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
168  auto originalMask =
169  cast<DenseIntElementsAttr>(constantOp.getValue());
170  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
171  paddedMaskValues.append(originalMask.template value_begin<bool>(),
172  originalMask.template value_end<bool>());
173  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
174 
175  // Compressing by combining every `numSrcElemsPerDest` elements:
176  SmallVector<bool> compressedMaskValues;
177  for (size_t i = 0; i < paddedMaskValues.size();
178  i += numSrcElemsPerDest) {
179  bool combinedValue = false;
180  for (int j = 0; j < numSrcElemsPerDest; ++j) {
181  combinedValue |= paddedMaskValues[i + j];
182  }
183  compressedMaskValues.push_back(combinedValue);
184  }
185  return rewriter.create<arith::ConstantOp>(
186  loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
187  });
188 
189  if (!newMask)
190  return failure();
191 
192  while (!extractOps.empty()) {
193  newMask = rewriter.create<vector::ExtractOp>(
194  loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
195  extractOps.pop_back();
196  }
197 
198  return *newMask;
199 }
200 
201 /// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
202 /// emitting `vector.extract_strided_slice`.
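/// For example (illustrative values): extracting 2 elements at `frontOffset`
/// = 1 from |0|1|2|3| : vector<4xi8> yields |1|2| : vector<2xi8>.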
203 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
204  Value source, int64_t frontOffset,
205  int64_t subvecSize) {
206  auto vectorType = cast<VectorType>(source.getType());
207  assert(vectorType.getRank() == 1 && "expected 1-D source types");
208  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
209  "subvector out of bounds");
210 
211  // No extraction is needed if the subvector size equals the source size.
212  if (vectorType.getNumElements() == subvecSize)
213  return source;
214 
215  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
216  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
217  auto strides = rewriter.getI64ArrayAttr({1});
218 
219  auto resultVectorType =
220  VectorType::get({subvecSize}, vectorType.getElementType());
221  return rewriter
222  .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
223  offsets, sizes, strides)
224  ->getResult(0);
225 }
226 
227 /// Inserts a 1-D subvector into a 1-D vector by overwriting the elements
228 /// starting at `offset`. It is a wrapper function for emitting
229 /// `vector.insert_strided_slice`.
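/// For example (illustrative values): inserting |9|9| : vector<2xi8> into
/// |0|1|2|3| : vector<4xi8> at `offset` = 1 yields |0|9|9|3|.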
230 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
231  Value src, Value dest, int64_t offset) {
232  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
233  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
234  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
235  "expected source and dest to be vector type");
236  auto offsets = rewriter.getI64ArrayAttr({offset});
237  auto strides = rewriter.getI64ArrayAttr({1});
238  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
239  dest, offsets, strides);
240 }
241 
242 /// Extracts a 1-D subvector from a 1-D `source` vector, starting at `offset`
243 /// with size `numElementsToExtract`, and inserts it into the `dest` vector. This
244 /// function emits multiple `vector.extract` and `vector.insert` ops, so only
245 /// use it when `offset` cannot be folded into a constant value.
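/// As a rough sketch (names are illustrative), for `numElementsToExtract` = 2
/// this emits:
///
///   %e0 = vector.extract %source[%offset]
///   %d0 = vector.insert %e0, %dest [0]
///   %o1 = arith.addi %offset, %c1
///   %e1 = vector.extract %source[%o1]
///   %d1 = vector.insert %e1, %d0 [1]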
246 static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
247  Value source, Value dest,
248  OpFoldResult offset,
249  int64_t numElementsToExtract) {
250  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
251  for (int i = 0; i < numElementsToExtract; ++i) {
252  Value extractLoc =
253  (i == 0) ? offset.dyn_cast<Value>()
254  : rewriter.create<arith::AddIOp>(
255  loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
256  rewriter.create<arith::ConstantIndexOp>(loc, i));
257  auto extractOp =
258  rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
259  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
260  }
261  return dest;
262 }
263 
264 /// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
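/// As a rough sketch, for `length` = 2 this emits a `vector.extract` of each
/// source element followed by a `vector.insert` into `dest` at positions
/// `destOffsetVar` and `destOffsetVar` + 1.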
265 static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
266  Value source, Value dest,
267  OpFoldResult destOffsetVar,
268  size_t length) {
269  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
270  assert(length > 0 && "length must be greater than 0");
271  Value destOffsetVal =
272  getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
273  for (size_t i = 0; i < length; ++i) {
274  auto insertLoc = i == 0
275  ? destOffsetVal
276  : rewriter.create<arith::AddIOp>(
277  loc, rewriter.getIndexType(), destOffsetVal,
278  rewriter.create<arith::ConstantIndexOp>(loc, i));
279  auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
280  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
281  }
282  return dest;
283 }
284 
285 /// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
286 ///
287 /// Specifically, use `containerElemTy` for loading a vector of
288 /// `emulatedElemTy`. The load location is given by `base` and
289 /// `linearizedIndices`, and the load size is given by
290 /// `numContainerElemsToLoad`.
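///
/// As an illustrative sketch, loading 2 `i8` container elements that hold
/// 4 emulated `i4` elements produces roughly:
///
///   %load = vector.load %base[%linearizedIndices] : memref<..xi8>, vector<2xi8>
///   %res = vector.bitcast %load : vector<2xi8> to vector<4xi4>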
291 static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
292  Value base,
293  OpFoldResult linearizedIndices,
294  int64_t numContainerElemsToLoad,
295  Type emulatedElemTy,
296  Type containerElemTy) {
297  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
298  emulatedElemTy.getIntOrFloatBitWidth();
299  auto newLoad = rewriter.create<vector::LoadOp>(
300  loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
301  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
302  return rewriter.create<vector::BitCastOp>(
303  loc,
304  VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
305  emulatedElemTy),
306  newLoad);
307 }
308 
309 /// Downcasts two values to `downcastType`, selects between them based on
310 /// `mask`, and casts the result to `upcastType`.
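///
/// As an illustrative sketch, with `downcastType` = vector<4xi2> and
/// `upcastType` = vector<1xi8> this produces roughly:
///
///   %t = vector.bitcast %trueValue : vector<1xi8> to vector<4xi2>
///   %f = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>
///   %s = arith.select %mask, %t, %f : vector<4xi1>, vector<4xi2>
///   %r = vector.bitcast %s : vector<4xi2> to vector<1xi8>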
311 static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
312  VectorType downcastType,
313  VectorType upcastType, Value mask,
314  Value trueValue, Value falseValue) {
315  assert(
316  downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
317  upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
318  "expected input and output number of bits to match");
319  if (trueValue.getType() != downcastType) {
320  trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
321  }
322  if (falseValue.getType() != downcastType) {
323  falseValue =
324  builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
325  }
326  Value selectedValue =
327  builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
328  // Upcast the selected value to the new type.
329  return builder.create<vector::BitCastOp>(loc, upcastType, selectedValue);
330 }
331 
332 /// Emits a `memref.generic_atomic_rmw` op to store a sub-byte-sized value to a
333 /// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
334 /// sub-byte-sized elements totalling 8 bits, and the mask selects which of
335 /// those elements to store.
336 ///
337 /// Inputs:
338 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
339 /// storeIdx = 2
340 /// valueToStore = |3|3|3|3| : vector<4xi2>
341 /// mask = |0|0|1|1| : vector<4xi1>
342 ///
343 /// Result:
344 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
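///
/// A rough sketch of the emitted IR (types elided, illustrative):
///
///   memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<..xi8> {
///   ^bb0(%orig: i8):
///     %orig_vec = vector.from_elements %orig : vector<1xi8>
///     %masked = ... // see `downcastSelectAndUpcast`
///     %scalar = vector.extract %masked[0] : i8 from vector<1xi8>
///     memref.atomic_yield %scalar : i8
///   }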
345 static void atomicRMW(OpBuilder &builder, Location loc,
346  MemRefValue linearizedMemref, Value storeIdx,
347  VectorValue valueToStore, Value mask) {
348  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
349 
350  // Create an atomic load-modify-write region using
351  // `memref.generic_atomic_rmw`.
352  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
353  loc, linearizedMemref, ValueRange{storeIdx});
354  Value origValue = atomicOp.getCurrentValue();
355 
356  OpBuilder::InsertionGuard guard(builder);
357  builder.setInsertionPointToStart(atomicOp.getBody());
358 
359  // Wrap the original value (the atomic RMW region's argument) into a
360  // one-element vector of the container element type.
361  auto oneElemVecType = VectorType::get({1}, origValue.getType());
362  Value origVecValue = builder.create<vector::FromElementsOp>(
363  loc, oneElemVecType, ValueRange{origValue});
364 
365  // Construct the final masked value and yield it.
366  Value maskedValue =
367  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
368  oneElemVecType, mask, valueToStore, origVecValue);
369  auto scalarMaskedValue =
370  builder.create<vector::ExtractOp>(loc, maskedValue, 0);
371  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
372 }
373 
374 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
375 /// type. It has similar logic to `atomicRMW`, but without atomicity.
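///
/// A rough sketch of the emitted IR (illustrative, for an i8 container):
///
///   %orig = vector.load %linearizedMemref[%linearizedIndex] : memref<..xi8>, vector<1xi8>
///   %orig_vec = vector.bitcast %orig : vector<1xi8> to vector<4xi2>
///   %sel = arith.select %mask, %valueToStore, %orig_vec : vector<4xi1>, vector<4xi2>
///   %new = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
///   vector.store %new, %linearizedMemref[%linearizedIndex] : memref<..xi8>, vector<1xi8>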
376 static void nonAtomicRMW(OpBuilder &builder, Location loc,
377  MemRefValue linearizedMemref, Value linearizedIndex,
378  VectorValue valueToStore, Value mask) {
379  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
380 
381  auto oneElemVecType =
382  VectorType::get({1}, linearizedMemref.getType().getElementType());
383  Value origVecValue = builder.create<vector::LoadOp>(
384  loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
385  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
386  origVecValue);
387 
388  Value maskedValue =
389  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
390  oneElemVecType, mask, valueToStore, origVecValue);
391  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
392  linearizedIndex);
393 }
394 
395 /// Extracts `sliceNumElements` elements from the source `vector` at
396 /// `extractOffset`, and inserts them into a zero-initialized vector at `insertOffset`.
397 /// Inputs:
398 /// vec_in = |0|1|2|3| : vector<4xi2>
399 /// extractOffset = 1
400 /// sliceNumElements = 2
401 /// insertOffset = 2
402 /// Output:
403 /// vec_out = |0|0|1|2| : vector<4xi2>
404 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
405  Location loc, VectorValue vector,
406  int64_t extractOffset,
407  int64_t sliceNumElements,
408  int64_t insertOffset) {
409  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
410  auto vectorElementType = vector.getType().getElementType();
411  // TODO: update and use `alignedConversionPrecondition` in the place of
412  // these asserts.
413  assert(
414  sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
415  "sliceNumElements * vector element size must be less than or equal to 8");
416  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
417  "vector element must be a valid sub-byte type");
418  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
419  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
420  loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
421  rewriter.getZeroAttr(
422  VectorType::get({emulatedPerContainerElem}, vectorElementType)));
423  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
424  extractOffset, sliceNumElements);
425  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
426  insertOffset);
427 }
428 
429 namespace {
430 
431 //===----------------------------------------------------------------------===//
432 // ConvertVectorStore
433 //===----------------------------------------------------------------------===//
434 
435 // TODO: Document-me
436 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
437  using OpConversionPattern::OpConversionPattern;
438 
439  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
440  : OpConversionPattern<vector::StoreOp>(context),
441  disableAtomicRMW(disableAtomicRMW) {}
442 
443  LogicalResult
444  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
445  ConversionPatternRewriter &rewriter) const override {
446 
447  // See #115653
448  if (op.getValueToStore().getType().getRank() != 1)
449  return rewriter.notifyMatchFailure(op,
450  "only 1-D vectors are supported ATM");
451 
452  auto loc = op.getLoc();
453 
454  auto valueToStore = cast<VectorValue>(op.getValueToStore());
455  auto containerElemTy =
456  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
457  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
458  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
459  int containerBits = containerElemTy.getIntOrFloatBitWidth();
460 
461  // Check per-element alignment.
462  if (containerBits % emulatedBits != 0) {
463  return rewriter.notifyMatchFailure(
464  op, "impossible to pack emulated elements into container elements "
465  "(bit-wise misalignment)");
466  }
467  int numSrcElemsPerDest = containerBits / emulatedBits;
468 
469  // Adjust the number of elements to store when emulating narrow types.
470  // Here only the 1-D vector store is considered, and the N-D memref types
471  // should be linearized.
472  // For example, to emulate i4 to i8, the following op:
473  //
474  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
475  //
476  // can be replaced with
477  //
478  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
479  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
480  // vector<4xi8>
481 
482  auto origElements = valueToStore.getType().getNumElements();
483  bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
484 
485  auto stridedMetadata =
486  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
487 
488  OpFoldResult linearizedIndices;
489  memref::LinearizedMemRefInfo linearizedInfo;
490  std::tie(linearizedInfo, linearizedIndices) =
491  memref::getLinearizedMemRefOffsetAndSize(
492  rewriter, loc, emulatedBits, containerBits,
493  stridedMetadata.getConstifiedMixedOffset(),
494  stridedMetadata.getConstifiedMixedSizes(),
495  stridedMetadata.getConstifiedMixedStrides(),
496  getAsOpFoldResult(adaptor.getIndices()));
497 
498  std::optional<int64_t> foldedNumFrontPadElems =
499  isAlignedEmulation
500  ? 0
501  : getConstantIntValue(linearizedInfo.intraDataOffset);
502 
503  if (!foldedNumFrontPadElems) {
504  return rewriter.notifyMatchFailure(
505  op, "subbyte store emulation: dynamic front padding size is "
506  "not yet implemented");
507  }
508 
509  auto memrefBase = cast<MemRefValue>(adaptor.getBase());
510 
511  // Conditions under which partial (RMW) stores are not needed:
512  // 1. The source vector size (in bits) is a multiple of the byte size.
513  // 2. The address of the store is aligned to the emulated width boundary.
514  //
515  // For example, storing a vector<4xi2> into a memref<13xi2> at offset 4 does
516  // not need unaligned emulation: the store address is aligned and the
517  // source spans a whole byte.
518  bool emulationRequiresPartialStores =
519  !isAlignedEmulation || *foldedNumFrontPadElems != 0;
520  if (!emulationRequiresPartialStores) {
521  // Basic case: storing full bytes.
522  auto numElements = origElements / numSrcElemsPerDest;
523  auto bitCast = rewriter.create<vector::BitCastOp>(
524  loc, VectorType::get(numElements, containerElemTy),
525  op.getValueToStore());
526  rewriter.replaceOpWithNewOp<vector::StoreOp>(
527  op, bitCast.getResult(), memrefBase,
528  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
529  return success();
530  }
531 
532  // Next, handle the case when sub-byte read-modify-write
533  // sequences are needed to emulate a vector store.
534  // Here is an example:
535  //
536  // Vector to store: vector<7xi2>
537  // Value to store: 11 11 11 11 11 11 11 (all ones)
538  //
539  // Destination: memref<12xi2>
540  // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
541  //
542  // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
543  //
544  // Destination memref before:
545  //
546  // Byte 0 Byte 1 Byte 2
547  // +----------+----------+----------+
548  // | 00000000 | 00000000 | 00000000 |
549  // +----------+----------+----------+
550  //
551  // Destination memref after:
552  //
553  // Byte 0 Byte 1 Byte 2
554  // +----------+----------+----------+
555  // | 00001111 | 11111111 | 11000000 |
556  // +----------+----------+----------+
557  //
558  // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
559  // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
560  // requiring RMW access (atomicity is required).
561 
562  // The index into the target memref we are storing to.
563  Value currentDestIndex =
564  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
565  // The index into the source vector we are currently processing.
566  auto currentSourceIndex = 0;
567 
568  // Build a mask used for rmw.
569  auto subWidthStoreMaskType =
570  VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
571 
572  auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
573 
574  // 1. Partial width store for the leading byte.
575  // When the store address is not aligned to the emulated width boundary,
576  // handle the unaligned part first so that the remaining elements are
577  // aligned to the width boundary.
578  auto frontSubWidthStoreElem =
579  (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
580  if (frontSubWidthStoreElem > 0) {
581  SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
582  if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
583  std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
584  origElements, true);
585  frontSubWidthStoreElem = origElements;
586  } else {
587  std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
588  frontSubWidthStoreElem, true);
589  }
590  auto frontMask = rewriter.create<arith::ConstantOp>(
591  loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
592 
593  currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
594  auto value =
595  extractSliceIntoByte(rewriter, loc, valueToStore, 0,
596  frontSubWidthStoreElem, *foldedNumFrontPadElems);
597 
598  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
599  cast<VectorValue>(value), frontMask.getResult());
600  }
601 
602  if (currentSourceIndex >= origElements) {
603  rewriter.eraseOp(op);
604  return success();
605  }
606 
607  // Increment the destination index by 1 to align to the emulated width
608  // boundary.
609  auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
610  currentDestIndex = rewriter.create<arith::AddIOp>(
611  loc, rewriter.getIndexType(), currentDestIndex, constantOne);
612 
613  // 2. Full width store for the inner output bytes.
614  // After the previous step, the store address is aligned to the emulated
615  // width boundary.
616  int64_t fullWidthStoreSize =
617  (origElements - currentSourceIndex) / numSrcElemsPerDest;
618  int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
619  if (fullWidthStoreSize > 0) {
620  auto fullWidthStorePart = staticallyExtractSubvector(
621  rewriter, loc, valueToStore, currentSourceIndex,
622  numNonFullWidthElements);
623 
624  auto originType = cast<VectorType>(fullWidthStorePart.getType());
625  auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
626  auto storeType = VectorType::get(
627  {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
628  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
629  fullWidthStorePart);
630  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
631  currentDestIndex);
632 
633  currentSourceIndex += numNonFullWidthElements;
634  currentDestIndex = rewriter.create<arith::AddIOp>(
635  loc, rewriter.getIndexType(), currentDestIndex,
636  rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
637  }
638 
639  // 3. Partial width store for the trailing output byte.
640  // It is needed when the residual length is smaller than the emulated width,
641  // which is not covered in step 2 above.
642  auto remainingElements = origElements - currentSourceIndex;
643  if (remainingElements != 0) {
644  auto subWidthStorePart =
645  extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
646  currentSourceIndex, remainingElements, 0);
647 
648  // Generate back mask.
649  auto maskValues = SmallVector<bool>(numSrcElemsPerDest, false);
650  std::fill_n(maskValues.begin(), remainingElements, true);
651  auto backMask = rewriter.create<arith::ConstantOp>(
652  loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
653 
654  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
655  cast<VectorValue>(subWidthStorePart), backMask.getResult());
656  }
657 
658  rewriter.eraseOp(op);
659  return success();
660  }
661 
662 private:
663  const bool disableAtomicRMW;
664 };
665 
666 //===----------------------------------------------------------------------===//
667 // ConvertVectorMaskedStore
668 //===----------------------------------------------------------------------===//
669 
670 // TODO: Document-me
671 struct ConvertVectorMaskedStore final
672  : OpConversionPattern<vector::MaskedStoreOp> {
673  using OpConversionPattern::OpConversionPattern;
674 
675  LogicalResult
676  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
677  ConversionPatternRewriter &rewriter) const override {
678 
679  // See #115653
680  if (op.getValueToStore().getType().getRank() != 1)
681  return rewriter.notifyMatchFailure(op,
682  "only 1-D vectors are supported ATM");
683 
684  auto loc = op.getLoc();
685  auto containerElemTy =
686  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
687  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
688  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
689  int containerBits = containerElemTy.getIntOrFloatBitWidth();
690 
691  // Check per-element alignment.
692  if (containerBits % emulatedBits != 0) {
693  return rewriter.notifyMatchFailure(
694  op, "impossible to pack emulated elements into container elements "
695  "(bit-wise misalignment)");
696  }
697 
698  int emulatedPerContainerElem = containerBits / emulatedBits;
699  int origElements = op.getValueToStore().getType().getNumElements();
700  if (origElements % emulatedPerContainerElem != 0)
701  return failure();
702 
703  auto stridedMetadata =
704  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
705  OpFoldResult linearizedIndicesOfr;
706  memref::LinearizedMemRefInfo linearizedInfo;
707  std::tie(linearizedInfo, linearizedIndicesOfr) =
708  memref::getLinearizedMemRefOffsetAndSize(
709  rewriter, loc, emulatedBits, containerBits,
710  stridedMetadata.getConstifiedMixedOffset(),
711  stridedMetadata.getConstifiedMixedSizes(),
712  stridedMetadata.getConstifiedMixedStrides(),
713  getAsOpFoldResult(adaptor.getIndices()));
714  Value linearizedIndices =
715  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
716 
717  // Load the whole data and use arith.select to handle the corner cases.
718  //
719  // As an example, for this masked store of i4 values:
720  //
721  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
722  //
723  // and given these input values:
724  //
725  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
726  // %0[%c0, %c0] =
727  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
728  // %val_to_store =
729  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
730  //
731  // we'll have the following i4 output:
732  //
733  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
734  //
735  // Emulating the above using i8 will give:
736  //
737  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
738  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
739  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
740  // %select_using_shifted_mask =
741  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
742  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
743  //
744  // Using the compressed mask to store %packed_data results in expected
745  // output.
746  //
747  // FIXME: Make an example based on the comment above work (see #115460 for
748  // reproducer).
749  FailureOr<Operation *> newMask = getCompressedMaskOp(
750  rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
751  if (failed(newMask))
752  return failure();
753 
754  auto numElements = (origElements + emulatedPerContainerElem - 1) /
755  emulatedPerContainerElem;
756  auto newType = VectorType::get(numElements, containerElemTy);
757  auto passThru = rewriter.create<arith::ConstantOp>(
758  loc, newType, rewriter.getZeroAttr(newType));
759 
760  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
761  loc, newType, adaptor.getBase(), linearizedIndices,
762  newMask.value()->getResult(0), passThru);
763 
764  auto newBitCastType =
765  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
766  Value valueToStore =
767  rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
768  valueToStore = rewriter.create<arith::SelectOp>(
769  loc, op.getMask(), op.getValueToStore(), valueToStore);
770  valueToStore =
771  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
772 
773  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
774  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
775  valueToStore);
776  return success();
777  }
778 };
779 
780 //===----------------------------------------------------------------------===//
781 // ConvertVectorLoad
782 //===----------------------------------------------------------------------===//
783 
784 // TODO: Document-me
785 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
786  using OpConversionPattern::OpConversionPattern;
787 
788  LogicalResult
789  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
790  ConversionPatternRewriter &rewriter) const override {
791 
792  // See #115653
793  if (op.getVectorType().getRank() != 1)
794  return rewriter.notifyMatchFailure(op,
795  "only 1-D vectors are supported ATM");
796 
797  auto loc = op.getLoc();
798  auto containerElemTy =
799  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
800  Type emulatedElemTy = op.getType().getElementType();
801  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
802  int containerBits = containerElemTy.getIntOrFloatBitWidth();
803 
804  // Check per-element alignment.
805  if (containerBits % emulatedBits != 0) {
806  return rewriter.notifyMatchFailure(
807  op, "impossible to pack emulated elements into container elements "
808  "(bit-wise misalignment)");
809  }
810  int emulatedPerContainerElem = containerBits / emulatedBits;
811 
812  // Adjust the number of elements to load when emulating narrow types,
813  // and then cast back to the original type with vector.bitcast op.
814  // Here only the 1-D vector load is considered, and the N-D memref types
815  // should be linearized.
816  // For example, to emulate i4 to i8, the following op:
817  //
818  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
819  //
820  // can be replaced with
821  //
822  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
823  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
824  //
825  // There are cases where the number of elements to load is not byte-aligned,
826  // for example:
827  //
828  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
829  //
830  // we will have to load extra bytes and extract the exact slice in between.
831  //
832  // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
833  // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
834  // %3 = vector.extract_strided_slice %1 {offsets = [2], sizes = [3], strides
835  // = [1]}
836  // : vector<8xi2> to vector<3xi2>
837  //
838  // TODO: Currently the extract_strided_slice's attributes must be known at
839  // compile time as they must be constants.
840 
841  auto origElements = op.getVectorType().getNumElements();
842  // Note, per-element-alignment was already verified above.
843  bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
844 
845  auto stridedMetadata =
846  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
847 
848  OpFoldResult linearizedIndices;
849  memref::LinearizedMemRefInfo linearizedInfo;
850  std::tie(linearizedInfo, linearizedIndices) =
851  memref::getLinearizedMemRefOffsetAndSize(
852  rewriter, loc, emulatedBits, containerBits,
853  stridedMetadata.getConstifiedMixedOffset(),
854  stridedMetadata.getConstifiedMixedSizes(),
855  stridedMetadata.getConstifiedMixedStrides(),
856  getAsOpFoldResult(adaptor.getIndices()));
857 
858  std::optional<int64_t> foldedIntraVectorOffset =
859  isFullyAligned ? 0
860  : getConstantIntValue(linearizedInfo.intraDataOffset);
861 
862  // Always load enough elements to cover the original elements.
863  int64_t maxIntraDataOffset =
864  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
865  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
866  emulatedPerContainerElem);
867  Value result =
868  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
869  numElements, emulatedElemTy, containerElemTy);
870 
871  if (!foldedIntraVectorOffset) {
872  auto resultVector = rewriter.create<arith::ConstantOp>(
873  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
874  result = dynamicallyExtractSubVector(
875  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
876  linearizedInfo.intraDataOffset, origElements);
877  } else if (!isFullyAligned) {
878  result = staticallyExtractSubvector(
879  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
880  }
881  rewriter.replaceOp(op, result);
882  return success();
883  }
884 };
885 
886 //===----------------------------------------------------------------------===//
887 // ConvertVectorMaskedLoad
888 //===----------------------------------------------------------------------===//
889 
890 // TODO: Document-me
891 struct ConvertVectorMaskedLoad final
892  : OpConversionPattern<vector::MaskedLoadOp> {
893  using OpConversionPattern::OpConversionPattern;
894 
895  LogicalResult
896  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
897  ConversionPatternRewriter &rewriter) const override {
898  // See #115653
899  if (op.getVectorType().getRank() != 1)
900  return rewriter.notifyMatchFailure(op,
901  "only 1-D vectors are supported ATM");
902 
903  auto loc = op.getLoc();
904 
905  auto containerElemTy =
906  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
907  Type emulatedElemTy = op.getType().getElementType();
908  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
909  int containerBits = containerElemTy.getIntOrFloatBitWidth();
910 
911  // Check per-element alignment.
912  if (containerBits % emulatedBits != 0) {
913  return rewriter.notifyMatchFailure(
914  op, "impossible to pack emulated elements into container elements "
915  "(bit-wise misalignment)");
916  }
917  int emulatedPerContainerElem = containerBits / emulatedBits;
918 
919  // Adjust the number of elements to load when emulating narrow types,
920  // and then cast back to the original type with vector.bitcast op.
921  // For example, to emulate i4 to i8, the following op:
922  //
923  // %mask = vector.constant_mask [3] : vector<6xi1>
924  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
925  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
926  //
927  // can be replaced with
928  //
929  // %new_mask = vector.constant_mask [2] : vector<3xi1>
930  // %new_pass_thru = vector.bitcast %pass_thru :
931  // vector<6xi4> to vector<3xi8>
932  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
933  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
934  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
935  //
936  // Since we are effectively loading 16 bits (2xi8) from the memref with the
937  // new mask, while originally we only wanted to effectively load 12 bits
938  // (3xi4) from the memref, we need to set the second half of the last i8
939  // that was effectively loaded (i.e. the second i8) to %pass_thru.
940  //
941  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
942  //
943  // Given these input values:
944  // %mask = [1, 1, 1, 0, 0, 0]
945  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
946  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
947  //
948  // we'll have:
949  //
950  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
951  //
952  // %new_mask = [1, 1, 0]
953  // %new_pass_thru = [0x78, 0x9A, 0xBC]
954  // %1 = [0x12, 0x34, 0xBC]
955  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
956  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
957  //
958  // TODO: Currently, only loading an even number of elements is supported.
959  // To deal with an odd number of elements, one has to extract the
960  // subvector at the proper offset after bit-casting.
961  auto origType = op.getVectorType();
962  auto origElements = origType.getNumElements();
963  bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
964 
965  auto stridedMetadata =
966  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
967  OpFoldResult linearizedIndices;
968  memref::LinearizedMemRefInfo linearizedInfo;
969  std::tie(linearizedInfo, linearizedIndices) =
970  memref::getLinearizedMemRefOffsetAndSize(
971  rewriter, loc, emulatedBits, containerBits,
972  stridedMetadata.getConstifiedMixedOffset(),
973  stridedMetadata.getConstifiedMixedSizes(),
974  stridedMetadata.getConstifiedMixedStrides(),
975  getAsOpFoldResult(adaptor.getIndices()));
976 
977  std::optional<int64_t> foldedIntraVectorOffset =
978  isAlignedEmulation
979  ? 0
980  : getConstantIntValue(linearizedInfo.intraDataOffset);
981 
982  int64_t maxIntraDataOffset =
983  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
984  FailureOr<Operation *> newMask =
985  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
986  emulatedPerContainerElem, maxIntraDataOffset);
987  if (failed(newMask))
988  return failure();
989 
990  Value passthru = op.getPassThru();
991 
992  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
993  emulatedPerContainerElem);
994  auto loadType = VectorType::get(numElements, containerElemTy);
995  auto newBitcastType =
996  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
997 
998  auto emptyVector = rewriter.create<arith::ConstantOp>(
999  loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1000  if (!foldedIntraVectorOffset) {
1001  passthru = dynamicallyInsertSubVector(
1002  rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1003  origElements);
1004  } else if (!isAlignedEmulation) {
1005  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1006  *foldedIntraVectorOffset);
1007  }
1008  auto newPassThru =
1009  rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
1010 
1011  // Generating the new masked load.
1012  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
1013  loc, loadType, adaptor.getBase(),
1014  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1015  newMask.value()->getResult(0), newPassThru);
1016 
1017  // Setting the part that originally was not effectively loaded from memory
1018  // to pass through.
1019  auto bitCast =
1020  rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
1021 
1022  Value mask = op.getMask();
1023  auto newSelectMaskType = VectorType::get(
1024  numElements * emulatedPerContainerElem, rewriter.getI1Type());
1025  // TODO: try to fold if op's mask is constant
1026  auto emptyMask = rewriter.create<arith::ConstantOp>(
1027  loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
1028  if (!foldedIntraVectorOffset) {
1029  mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1030  linearizedInfo.intraDataOffset,
1031  origElements);
1032  } else if (!isAlignedEmulation) {
1033  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1034  *foldedIntraVectorOffset);
1035  }
1036 
1037  Value result =
1038  rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
1039  if (!foldedIntraVectorOffset) {
1040  result = dynamicallyExtractSubVector(
1041  rewriter, loc, result, op.getPassThru(),
1042  linearizedInfo.intraDataOffset, origElements);
1043  } else if (!isAlignedEmulation) {
1044  result = staticallyExtractSubvector(
1045  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1046  }
1047  rewriter.replaceOp(op, result);
1048 
1049  return success();
1050  }
1051 };
1052 
1053 //===----------------------------------------------------------------------===//
1054 // ConvertVectorTransferRead
1055 //===----------------------------------------------------------------------===//
1056 
1057 // TODO: Document-me
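// As an illustrative sketch, a 1-D `vector.transfer_read` of sub-byte (e.g.
// i4) values is rewritten roughly as:
//
//   %pad = arith.extui %padding : i4 to i8
//   %r = vector.transfer_read %src[%linear_index], %pad : memref<..xi8>, vector<2xi8>
//   %v = vector.bitcast %r : vector<2xi8> to vector<4xi4>
//
// followed, in the unaligned case, by extracting the original elements.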
1058 struct ConvertVectorTransferRead final
1059  : OpConversionPattern<vector::TransferReadOp> {
1060  using OpConversionPattern::OpConversionPattern;
1061 
1062  LogicalResult
1063  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1064  ConversionPatternRewriter &rewriter) const override {
1065 
1066  // See #115653
1067  if (op.getVectorType().getRank() != 1)
1068  return rewriter.notifyMatchFailure(op,
1069  "only 1-D vectors are supported ATM");
1070 
1071  auto loc = op.getLoc();
1072  auto containerElemTy =
1073  cast<MemRefType>(adaptor.getSource().getType()).getElementType();
1074  Type emulatedElemTy = op.getType().getElementType();
1075  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1076  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1077 
1078  // Check per-element alignment.
1079  if (containerBits % emulatedBits != 0) {
1080  return rewriter.notifyMatchFailure(
1081  op, "impossible to pack emulated elements into container elements "
1082  "(bit-wise misalignment)");
1083  }
1084  int emulatedPerContainerElem = containerBits / emulatedBits;
1085 
1086  auto origElements = op.getVectorType().getNumElements();
1087 
1088  // Note, per-element-alignment was already verified above.
1089  bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1090 
1091  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
1092  adaptor.getPadding());
1093 
1094  auto stridedMetadata =
1095  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
1096 
1097  OpFoldResult linearizedIndices;
1098  memref::LinearizedMemRefInfo linearizedInfo;
1099  std::tie(linearizedInfo, linearizedIndices) =
1100  memref::getLinearizedMemRefOffsetAndSize(
1101  rewriter, loc, emulatedBits, containerBits,
1102  stridedMetadata.getConstifiedMixedOffset(),
1103  stridedMetadata.getConstifiedMixedSizes(),
1104  stridedMetadata.getConstifiedMixedStrides(),
1105  getAsOpFoldResult(adaptor.getIndices()));
1106 
1107  std::optional<int64_t> foldedIntraVectorOffset =
1108  isFullyAligned ? 0
1109  : getConstantIntValue(linearizedInfo.intraDataOffset);
1110 
1111  int64_t maxIntraDataOffset =
1112  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1113  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1114  emulatedPerContainerElem);
1115 
1116  auto newRead = rewriter.create<vector::TransferReadOp>(
1117  loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
1118  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1119  newPadding);
1120 
1121  auto bitCast = rewriter.create<vector::BitCastOp>(
1122  loc,
1123  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1124  newRead);
1125 
1126  Value result = bitCast->getResult(0);
1127  if (!foldedIntraVectorOffset) {
1128  auto zeros = rewriter.create<arith::ConstantOp>(
1129  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1130  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1131  linearizedInfo.intraDataOffset,
1132  origElements);
1133  } else if (!isFullyAligned) {
1134  result = staticallyExtractSubvector(
1135  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1136  }
1137  rewriter.replaceOp(op, result);
1138 
1139  return success();
1140  }
1141 };
1142 } // end anonymous namespace
1143 
1144 //===----------------------------------------------------------------------===//
1145 // RewriteBitCastOfTruncI
1146 //===----------------------------------------------------------------------===//
1147 
1148 namespace {
1149 
1150 /// Helper struct to keep track of the provenance of a contiguous set of bits
1151 /// in a source vector.
1152 struct SourceElementRange {
1153  /// The index of the source vector element that contributes bits to *this.
1154  int64_t sourceElementIdx;
1155  /// The range of bits in the source vector element that contribute to *this.
1156  int64_t sourceBitBegin;
1157  int64_t sourceBitEnd;
1158 };
1159 
1160 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1161  /// Given the index of a SourceElementRange in the SourceElementRangeList,
1162  /// compute the amount of bits that need to be shifted to the left to get the
1163  /// bits in their final location. This shift amount is simply the sum of the
1164  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1165 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
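/// For example, for the row `{1: b@[3..5)}{2: b@[0..5)}{3: b@[0..1)}` in the
/// table further below, the left-shift amounts are 0, 2, and 7, respectively.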
1166  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1167  int64_t res = 0;
1168  for (int64_t i = 0; i < shuffleIdx; ++i)
1169  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1170  return res;
1171  }
1172 };
1173 
1174 /// Helper struct to enumerate the source elements and bit ranges that are
1175 /// involved in a bitcast operation.
1176 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1177 /// any 1-D vector shape and any source/target bitwidths.
1178 /// This creates and holds a mapping of the form:
1179 /// [dstVectorElementJ] ==
1180 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1181 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1182 /// [0] = {0, [0-8)}
1183 /// [1] = {0, [8-16)}
1184 /// [2] = {0, [16-24)}
1185 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
1186 /// [0] = {0, [0-10)}   [1] = {0, [10-15)}, {1, [0-5)}
1187 /// [2] = {1, [5-15)}
1188 struct BitCastBitsEnumerator {
1189  BitCastBitsEnumerator(VectorType sourceVectorType,
1190  VectorType targetVectorType);
1191 
1192  int64_t getMaxNumberOfEntries() {
1193  int64_t numVectors = 0;
1194  for (const auto &l : sourceElementRanges)
1195  numVectors = std::max(numVectors, (int64_t)l.size());
1196  return numVectors;
1197  }
1198 
1199  VectorType sourceVectorType;
1200  VectorType targetVectorType;
1201  SmallVector<SourceElementRangeList> sourceElementRanges;
1202 };
1203 
1204 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1205 /// advantage of high-level information to avoid leaving LLVM to scramble with
1206 /// peephole optimizations.
1207 /// BitCastBitsEnumerator encodes for each element of the target vector the
1208 /// provenance of the bits in the source vector. We can "transpose" this
1209 /// information to build a sequence of shuffles and bitwise ops that will
1210 /// produce the desired result.
1211 //
1212 /// Consider the following motivating example:
1213 /// ```
1214 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1215 /// ```
1216 //
1217 /// BitCastBitsEnumerator contains the following information:
1218 /// ```
1219 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1220 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1221 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1222 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1223 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1224 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1225 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1226 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1227 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1228 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1229 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1230 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1231 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1232 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1233 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1234 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1235 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1236 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1237 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1238 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1239 /// ```
1240 ///
1241 /// In the above, each row represents one target vector element and each
1242 /// column represents one bit contribution from a source vector element.
1243 /// The algorithm creates vector.shuffle operations (in this case there are 3
1244 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1245 /// algorithm populates the bits as follows:
1246 /// ```
1247 /// src bits 0 ...
1248 /// 1st shuffle |xxxxx |xx |...
1249 /// 2nd shuffle | xxx| xxxxx |...
1250 /// 3rd shuffle | | x|...
1251 /// ```
1252 //
1253 /// The algorithm proceeds as follows:
1254 /// 1. for each vector.shuffle, collect the source vectors that participate in
1255 /// this shuffle. One source vector per target element of the resulting
1256 /// vector.shuffle. If there is no source element contributing bits for the
1257 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1258 /// 2 columns).
1259 /// 2. represent the bitrange in the source vector as a mask. If there is no
1260 /// source element contributing bits for the current vector.shuffle, take 0.
1261 /// 3. shift right by the proper amount to align the source bitrange at
1262 /// position 0. This is exactly the low end of the bitrange. For instance,
1263 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1264 /// shift right by 3 to get the bits contributed by the source element #1
1265 /// into position 0.
1266 /// 4. shift left by the proper amount to align to the desired position in
1267 /// the result element vector. For instance, the contribution of the second
1268 /// source element for the first row needs to be shifted by `5` to form the
1269 /// first i8 result element.
1270 ///
1271 /// Eventually, we end up building the sequence
1272 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1273 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1274 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1275 struct BitCastRewriter {
1276  /// Helper metadata struct to hold the static quantities for the rewrite.
1277  struct Metadata {
1278  SmallVector<int64_t> shuffles;
1279  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1280  };
1281 
1282  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1283 
1284  /// Verify that general preconditions for the rewrite are met.
1285  LogicalResult commonPrecondition(PatternRewriter &rewriter,
1286  VectorType preconditionType, Operation *op);
1287 
1288  /// Precompute the metadata for the rewrite.
1289  SmallVector<BitCastRewriter::Metadata>
1290  precomputeMetadata(IntegerType shuffledElementType);
1291 
1292  /// Rewrite one step of the sequence:
1293  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1294  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1295  Value initialValue, Value runningResult,
1296  const BitCastRewriter::Metadata &metadata);
1297 
1298 private:
1299  /// Underlying enumerator that encodes the provenance of the bits in the each
1300  /// element of the result vector.
1301  BitCastBitsEnumerator enumerator;
1302 };
1303 
1304 } // namespace
1305 
1306 [[maybe_unused]] static raw_ostream &
1307 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1308  for (const auto &l : vec) {
1309  for (auto it : llvm::enumerate(l)) {
1310  os << "{ " << it.value().sourceElementIdx << ": b@["
1311  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1312  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1313  }
1314  os << "\n";
1315  }
1316  return os;
1317 }
1318 
1319 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1320  VectorType targetVectorType)
1321  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1322 
1323  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1324  "requires 1-D non-scalable vector type");
1325  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1326  "requires 1-D non-scalable vector type");
1327  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1328  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1329  LDBG("sourceVectorType: " << sourceVectorType);
1330 
1331  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1332  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1333  LDBG("targetVectorType: " << targetVectorType);
1334 
1335  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1336  (void)mostMinorSourceDim;
1337  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1338  "source and target bitwidths must match");
1339 
1340  // Prepopulate one source element range per target element.
1341  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1342  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1343  int64_t resultElement = resultBit / targetBitWidth;
1344  int64_t resultBitInElement = resultBit % targetBitWidth;
1345  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1346  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1347  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1348  targetBitWidth - resultBitInElement);
1349  sourceElementRanges[resultElement].push_back(
1350  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1351  resultBit += step;
1352  }
1353 }
1354 
1355 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1356  VectorType targetVectorType)
1357  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1358  LDBG("\n" << enumerator.sourceElementRanges);
1359 }
1360 
1361 /// Verify that the precondition type meets the common preconditions for any
1362 /// conversion.
1363 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1364  VectorType preconditionType,
1365  Operation *op) {
1366  if (!preconditionType || preconditionType.isScalable())
1367  return rewriter.notifyMatchFailure(op, "scalable vector");
1368 
1369  // TODO: consider relaxing this restriction in the future if we find ways
1370  // to really work with subbyte elements across the MLIR/LLVM boundary.
1371  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1372  if (bitwidth % 8 != 0)
1373  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1374 
1375  return success();
1376 }
1377 
1378 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1379  VectorType preconditionType,
1380  Operation *op) {
1381  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1382  return rewriter.notifyMatchFailure(op, "types are not vector");
1383 
1384  if (!preconditionType || preconditionType.getRank() != 1)
1385  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1386 
1387  return commonConversionPrecondition(rewriter, preconditionType, op);
1388 }
1389 
1390 /// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1391 /// means that:
1392 /// 1. The bitwidth of the `dstType` element type is a multiple of the
1393 /// bitwidth of the `subByteVecType` element type (e.g. i4 vs i8 is OK,
1394 /// but i3 vs i8 is not supported).
1395 /// 2. The number of (trailing) elements in `subByteVecType` is a multiple
1396 /// of `8 / <sub-byte bitwidth>` (e.g., when targeting i8, 2xi4 is OK, but
1397 /// 3xi4 is not supported).
1398 ///
1399 /// NOTE: This method assumes that common conversion preconditions are met. In
1400 /// particular, the element type of `dstType` is assumed to be a multi-byte
1401 /// type (e.g. i8, i16, i32).
1402 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1403  VectorType subByteVecType,
1404  VectorType dstType,
1405  Operation *op) {
1406  if (!subByteVecType || !dstType)
1407  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1408  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1409  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1410 
1411  if (dstElemBitwidth < 8)
1412  return rewriter.notifyMatchFailure(
1413  op, "the bitwidth of dstType must be greater than or equal to 8");
1414  if (dstElemBitwidth % srcElemBitwidth != 0)
1415  return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1416  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1417  return rewriter.notifyMatchFailure(
1418  op, "only src bitwidth of 2 or 4 is supported at this moment");
1419 
1420  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1421  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1422  return rewriter.notifyMatchFailure(
1423  op, "the trailing dimension of the input vector of sub-bytes must be a "
1424  "multiple of 8 / <sub-byte-width>");
1425 
1426  return success();
1427 }
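// For illustration (hypothetical shapes, not exhaustive): with an i8 `dstType`
// element type, vector<8xi4> and vector<4xi2> satisfy this precondition,
// whereas vector<3xi4> (trailing dimension not a multiple of 2), vector<8xi3>
// (unsupported sub-byte width) and any case with a destination element type
// narrower than 8 bits are rejected.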
1428 
1429 SmallVector<BitCastRewriter::Metadata>
1430 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1431  SmallVector<BitCastRewriter::Metadata> result;
1432  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1433  shuffleIdx < e; ++shuffleIdx) {
1434  SmallVector<int64_t> shuffles;
1435  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1436 
1437  // Create the attribute quantities for the shuffle / mask / shift ops.
1438  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1439  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1440  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1441  : 0;
1442  shuffles.push_back(sourceElement);
1443 
1444  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1445  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1446  : 0;
1447  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1448  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1449  : 0;
1450  IntegerAttr mask = IntegerAttr::get(
1451  shuffledElementType,
1452  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1453  bitLo, bitHi));
1454  masks.push_back(mask);
1455 
1456  int64_t shiftRight = bitLo;
1457  shiftRightAmounts.push_back(
1458  IntegerAttr::get(shuffledElementType, shiftRight));
1459 
1460  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1461  shiftLeftAmounts.push_back(
1462  IntegerAttr::get(shuffledElementType, shiftLeft));
1463  }
1464 
1465  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1466  }
1467  return result;
1468 }
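// Continuing the illustrative vector<8xi4> -> vector<4xi8> example with an
// i32 shuffled element type (e.g. when the bitcast source comes from a trunci
// of vector<8xi32>), two Metadata entries are produced:
//   step 0: shuffles = [0, 2, 4, 6], mask = 0xf, shiftRight = 0, shiftLeft = 0
//   step 1: shuffles = [1, 3, 5, 7], mask = 0xf, shiftRight = 0, shiftLeft = 4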
1469 
1470 Value BitCastRewriter::genericRewriteStep(
1471  PatternRewriter &rewriter, Location loc, Value initialValue,
1472  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1473  // Create vector.shuffle from the metadata.
1474  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
1475  loc, initialValue, initialValue, metadata.shuffles);
1476 
1477  // Intersect with the mask.
1478  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1479  auto constOp = rewriter.create<arith::ConstantOp>(
1480  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1481  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
1482 
1483  // Align right on 0.
1484  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
1485  loc,
1486  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1487  Value shiftedRight =
1488  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
1489 
1490  // Shift bits left into their final position.
1491  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
1492  loc,
1493  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1494  Value shiftedLeft =
1495  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
1496 
1497  runningResult =
1498  runningResult
1499  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
1500  : shiftedLeft;
1501 
1502  return runningResult;
1503 }
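// Schematically, the second step of the example above (operating on the
// pre-truncation vector<8xi32> value) would generate IR resembling the
// following; the SSA names are placeholders, not actual output:
//   %s = vector.shuffle %in, %in [1, 3, 5, 7] : vector<8xi32>, vector<8xi32>
//   %a = arith.andi %s, %cst_mask_0xf : vector<4xi32>
//   %r = arith.shrui %a, %cst_shr_0 : vector<4xi32>
//   %l = arith.shli %r, %cst_shl_4 : vector<4xi32>
//   %o = arith.ori %runningResult, %l : vector<4xi32>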
1504 
1505 /// Bitcasts the aligned `subByteVec` vector to a vector of i8, where
1506 /// "aligned" means that it satisfies `alignedConversionPrecondition`.
1507 ///
1508 /// Example:
1509 /// vector<16x16xi2> -> vector<16x4xi8>
1510 /// vector<16x16xi4> -> vector<16x8xi8>
1511 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1512  Value subByteVec) {
1513  auto srcVecType = cast<VectorType>(subByteVec.getType());
1514  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1515  assert(8 % srcBitwidth == 0 &&
1516  "Unsupported sub-byte type (not a divisor of i8)");
1517  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1518  SmallVector<int64_t> vecShape(srcVecType.getShape());
1519  // Adjust last dimension of the vector, so the total size remains the same.
1520  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1521  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1522  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
1523 }
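// E.g. (illustrative), for the first shape mentioned above:
//   %0 = vector.bitcast %src : vector<16x16xi2> to vector<16x4xi8>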
1524 
1525 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1526 /// starting at the specified bit index.
1527 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1528 ///
1529 /// Example for a single element:
1530 /// Extract numBits=2 starting at bitIdx=2
1531 /// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1532 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1533 /// target = [. . . . ^ ^ . .]
1534 ///
1535 /// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1536 /// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1537 ///
1538 /// src = [01 01 11 10]
1539 /// shl = arith.shl(src, 4) -> [11 10 00 00]
1540 /// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1541 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1542  Location loc, Value src,
1543  int bitIdx, int numBits) {
1544  auto srcType = cast<VectorType>(src.getType());
1545  Value shl = src;
1546  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1547  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1548  "Invalid bitIdx range");
1549  if (bitsToShiftLeft != 0) {
1550  Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1551  loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1552  shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1553  }
1554 
1555  int8_t bitsToShiftRight = 8 - numBits;
1556  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1557  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1558  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1559  return shr;
1560 }
1561 
1562 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1563 /// starting at the specified bit index.
1564 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1565 ///
1566 /// Example for a single element:
1567 /// Extract numBits=2 starting at bitIdx=2
1568 /// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1569 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1570 /// target = [. . . . ^ ^ . .]
1571 ///
1572 /// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1573 /// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1574 ///
1575 /// src = [01 01 10 10]
1576 /// mask = [00 00 00 11]
1577 /// shr = arith.shrui(src, 2) = [00 01 01 10]
1578 /// result = arith.andi(shr, mask) = [00 00 00 10]
1579 /// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1580 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1581 /// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1582 /// left when the index is 0.
1583 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1584  Location loc, Value src,
1585  int bitIdx, int numBits) {
1586  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1587  "Invalid bitIdx range");
1588  auto srcType = cast<VectorType>(src.getType());
1589  int8_t bitsToShiftRight = bitIdx;
1590  Value shr = src;
1591  if (bitsToShiftRight != 0) {
1592  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1593  loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1594  shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1595  }
1596  if (bitIdx + numBits == 8) {
1597  return shr;
1598  }
1599  uint8_t lowBitsMask = (1 << numBits) - 1;
1600  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1601  loc, DenseElementsAttr::get(srcType, lowBitsMask));
1602  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
1603 }
1604 
1605 using ExtractNBitsFn =
1606  std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1607 
1608 /// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1609 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1610 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1611  Value srcValue, const ExtractNBitsFn &extFn) {
1612  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1613  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1614  "Expected i4 type");
1615 
1616  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1617  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1618 
1619  // 2. Extend i4 elements to i8 elements. Low i4 elements of each
1620  // byte are placed in one vector and the high i4 elements in another vector.
1621  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1622  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1623 
1624  // 3. Interleave low and high i8 elements.
1625  return rewriter.create<vector::InterleaveOp>(loc, low, high);
1626 }
1627 
1628 /// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1629 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1630 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1631  Value srcValue, const ExtractNBitsFn &extFn) {
1632  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1633  assert(srcVecType.getElementType().isSignlessInteger(2) &&
1634  "Expected i2 type");
1635 
1636  // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1637  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1638 
1639  // 2. Extract each i2 element.
1640  // Position 0 (bits 0-1)
1641  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1642  // Position 1 (bits 2-3)
1643  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1644  // Position 2 (bits 4-5)
1645  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1646  // Position 3 (bits 6-7)
1647  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1648 
1649  // 3. Interleave all 4 elements by first interleaving
1650  // even elements and then odd
1651  // vec0 = [0,0,0,0],...
1652  // vec1 = [1,1,1,1],...
1653  // vec2 = [2,2,2,2],...
1654  // vec3 = [3,3,3,3],...
1655  // 02 = [0,2,0,2,0,2,0,2],...
1656  // 13 = [1,3,1,3,1,3,1,3],...
1657  // 0213 = [0,1,2,3,...],...
1658  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
1659  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
1660  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1661 }
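// Schematically, for a vector<8xi2> input this produces (modulo SSA names and
// the exact extract-and-extend sequences chosen by `extFn`):
//   %b = vector.bitcast %in : vector<8xi2> to vector<2xi8>
//   %e0, %e1, %e2, %e3 = four N-bit extractions from %b (one per position)
//   %i02 = vector.interleave %e0, %e2 : vector<2xi8> -> vector<4xi8>
//   %i13 = vector.interleave %e1, %e3 : vector<2xi8> -> vector<4xi8>
//   %res = vector.interleave %i02, %i13 : vector<4xi8> -> vector<8xi8>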
1662 
1663 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1664 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
1665 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1666  Value srcValue) {
1667  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1668  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1669  "Expected i8 type");
1670 
1671  // 1. De-interleave low and high i8 elements.
1672  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
1673 
1674  // 2. Zero out the upper side of each low i8 element.
1675  constexpr int8_t i8LowBitMask = 0x0F;
1676  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1677  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
1678  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1679  Value zeroOutLow = rewriter.create<arith::AndIOp>(
1680  loc, deinterleaveOp.getRes1(), zeroOutMask);
1681 
1682  // 3. Move high i4 values to upper side of the byte.
1683  constexpr int8_t bitsToShift = 4;
1684  auto shiftValues = rewriter.create<arith::ConstantOp>(
1685  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1686  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
1687  shiftValues);
1688 
1689  // 4. Merge high and low i4 values.
1690  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
1691 
1692  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1693  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1694  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
1695 }
1696 
1697 namespace {
1698 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1699 /// advantage of high-level information to avoid leaving LLVM to scramble with
1700 /// peephole optimizations.
1701 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1702  using OpRewritePattern::OpRewritePattern;
1703 
1704  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1705  PatternRewriter &rewriter) const override {
1706  // The source must be a trunc op.
1707  auto truncOp =
1708  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1709  if (!truncOp)
1710  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1711 
1712  // Set up the BitCastRewriter and verify the precondition.
1713  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1714  VectorType targetVectorType = bitCastOp.getResultVectorType();
1715  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1716  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1717  return failure();
1718 
1719  // Perform the rewrite.
1720  Value truncValue = truncOp.getIn();
1721  auto shuffledElementType =
1722  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1723  Value runningResult;
1724  for (const BitCastRewriter::Metadata &metadata :
1725  bcr.precomputeMetadata(shuffledElementType)) {
1726  runningResult = bcr.genericRewriteStep(
1727  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1728  }
1729 
1730  // Finalize the rewrite.
1731  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1732  shuffledElementType.getIntOrFloatBitWidth();
1733  if (narrowing) {
1734  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1735  rewriter.replaceOp(bitCastOp, runningResult);
1736  } else {
1737  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1738  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1739  }
1740  } else {
1741  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1742  rewriter.replaceOp(bitCastOp, runningResult);
1743  } else {
1744  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1745  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1746  }
1747  }
1748 
1749  return success();
1750  }
1751 };
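// A representative (hypothetical) input that this pattern matches:
//   %1 = arith.trunci %0 : vector<8xi32> to vector<8xi4>
//   %2 = vector.bitcast %1 : vector<8xi4> to vector<4xi8>
// The bitcast is rewritten into the shuffle/mask/shift/or sequence above; as
// the result element type (i8) is narrower than the shuffled element type
// (i32), the rewrite finishes with an arith.trunci to vector<4xi8>.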
1752 } // namespace
1753 
1754 //===----------------------------------------------------------------------===//
1755 // RewriteExtOfBitCast
1756 //===----------------------------------------------------------------------===//
1757 
1758 namespace {
1759 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1760 /// take advantage of high-level information to avoid leaving LLVM to scramble
1761 /// with peephole optimizations.
1762 template <typename ExtOpType>
1763 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1764  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1765 
1766  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1767  : OpRewritePattern<ExtOpType>(context, benefit) {}
1768 
1769  LogicalResult matchAndRewrite(ExtOpType extOp,
1770  PatternRewriter &rewriter) const override {
1771  // The source must be a bitcast op.
1772  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1773  if (!bitCastOp)
1774  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1775 
1776  // Set up the BitCastRewriter and verify the precondition.
1777  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1778  VectorType targetVectorType = bitCastOp.getResultVectorType();
1779  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1780  if (failed(bcr.commonPrecondition(
1781  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1782  return failure();
1783 
1784  // Perform the rewrite.
1785  Value runningResult;
1786  Value sourceValue = bitCastOp.getSource();
1787  auto shuffledElementType =
1788  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1789  for (const BitCastRewriter::Metadata &metadata :
1790  bcr.precomputeMetadata(shuffledElementType)) {
1791  runningResult = bcr.genericRewriteStep(
1792  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1793  }
1794 
1795  // Finalize the rewrite.
1796  bool narrowing =
1797  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1798  shuffledElementType.getIntOrFloatBitWidth();
1799  if (narrowing) {
1800  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1801  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1802  } else {
1803  rewriter.replaceOpWithNewOp<ExtOpType>(
1804  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1805  }
1806 
1807  return success();
1808  }
1809 };
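// A representative (hypothetical) input that this pattern matches:
//   %1 = vector.bitcast %0 : vector<4xi8> to vector<8xi4>
//   %2 = arith.extsi %1 : vector<8xi4> to vector<8xi32>
// The rewrite recombines the bits directly from the i8 source; since the
// result element type (i32) is wider than the shuffled element type (i8), it
// finishes with an arith.extsi of the recombined vector<8xi8> value.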
1810 
1811 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1812 /// bitwise ops that take advantage of high-level information to avoid leaving
1813 /// LLVM to scramble with peephole optimizations. Templated to choose between
1814 /// signed and unsigned conversions.
1815 ///
1816 /// EXAMPLE 1 (signed):
1817 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1818 /// is rewritten as:
1819 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1820 /// %1 = arith.shli %0, 4 : vector<4xi8>
1821 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1822 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1823 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1824 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1825 ///
1826 /// EXAMPLE 2 (fp):
1827 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1828 /// is rewritten as:
1829 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1830 /// %1 = arith.shli %0, 4 : vector<4xi8>
1831 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1832 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1833 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1834 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1835 ///
1836 /// EXAMPLE 3 (unsigned):
1837 /// arith.extui %in : vector<8xi4> to vector<8xi32>
1838 /// is rewritten as:
1839 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1840 /// %1 = arith.andi %0, 15 : vector<4xi8>
1841 /// %2 = arith.shrui %0, 4 : vector<4xi8>
1842 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1843 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1844 ///
1845 template <typename ConversionOpType, bool isSigned>
1846 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1847  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1848 
1849  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1850  PatternRewriter &rewriter) const override {
1851  // Verify the preconditions.
1852  Value srcValue = conversionOp.getIn();
1853  VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
1854  VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1855 
1856  if (failed(
1857  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1858  return failure();
1859 
1860  // Check general alignment preconditions.
1861  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1862  conversionOp)))
1863  return failure();
1864 
1865  // Perform the rewrite.
1866  Location loc = conversionOp.getLoc();
1867  const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
1868  : extractNBitsPerByteAndExtendToI8;
1869  Value subByteExt;
1870  switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1871  case 2:
1872  subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
1873  break;
1874  case 4:
1875  subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
1876  break;
1877  default:
1878  return failure();
1879  }
1880 
1881  // Finalize the rewrite.
1882  rewriter.replaceOpWithNewOp<ConversionOpType>(
1883  conversionOp, conversionOp.getType(), subByteExt);
1884  return success();
1885  }
1886 };
1887 
1888 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1889 /// bitwise ops that take advantage of high-level information to avoid leaving
1890 /// LLVM to scramble with peephole optimizations.
1891 ///
1892 /// For example:
1893 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1894 ///
1895 /// is rewritten as (an initial arith.trunci to vector<8xi8> is omitted):
1896 ///
1897 /// %cst = arith.constant dense<15> : vector<4xi8>
1898 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1899 /// %0, %1 = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
1900 /// %2 = arith.andi %0, %cst : vector<4xi8>
1901 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1902 /// %4 = arith.ori %2, %3 : vector<4xi8>
1903 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1904 ///
1905 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1906  using OpRewritePattern::OpRewritePattern;
1907 
1908  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1909  PatternRewriter &rewriter) const override {
1910  // Verify the preconditions.
1911  Value srcValue = truncOp.getIn();
1912  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1913  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1914  if (!srcVecType || !dstVecType)
1915  return failure();
1916 
1917  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1918  return failure();
1919 
1920  // TODO: Add support for truncating to i2.
1921  if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
1922  return failure();
1923 
1924  // Check general alignment preconditions. We invert the src/dst type order
1925  // to reuse the existing precondition logic.
1926  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1927  truncOp)))
1928  return failure();
1929 
1930  // Create a new iX -> i8 truncation op.
1931  Location loc = truncOp.getLoc();
1932  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1933  Value i8TruncVal =
1934  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1935 
1936  // Rewrite the i8 -> i4 truncation part.
1937  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1938 
1939  // Finalize the rewrite.
1940  rewriter.replaceOp(truncOp, subByteTrunc);
1941  return success();
1942  }
1943 };
1944 
1945 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1946 /// perform the transpose on wider (byte) element types.
1947 ///
1948 /// EXAMPLE:
1949 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1950 ///
1951 /// is rewritten as:
1952 ///
1953 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1954 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1955 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1956 ///
1957 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1958  using OpRewritePattern::OpRewritePattern;
1959 
1960  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1961  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1962 
1963  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1964  PatternRewriter &rewriter) const override {
1965  // Precondition: sub-byte integer transpose.
1966  constexpr unsigned minNativeBitwidth = 8;
1967  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1968  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1969  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1970  return rewriter.notifyMatchFailure(transposeOp,
1971  "not a sub-byte transpose");
1972  }
1973 
1974  // Perform the rewrite.
1975  Location loc = transposeOp.getLoc();
1976  // Signed/unsigned interpretation shouldn't matter here as we are just
1977  // transposing the elements and truncating them back to the original size.
1978  // TODO: Use unsigned extension (more efficient) when emulation or backend
1979  // support is available.
1980  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1981  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1982  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1983  transposeOp.getVector());
1984  Value newTranspose = rewriter.create<vector::TransposeOp>(
1985  loc, extOp, transposeOp.getPermutation());
1986  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1987  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1988  newTranspose);
1989  return success();
1990  }
1991 };
1992 
1993 } // namespace
1994 
1995 //===----------------------------------------------------------------------===//
1996 // Public Interface Definition
1997 //===----------------------------------------------------------------------===//
1998 
1999 // The emulated type is inferred from the converted memref type.
2000 void vector::populateVectorNarrowTypeEmulationPatterns(
2001  const arith::NarrowTypeEmulationConverter &typeConverter,
2002  RewritePatternSet &patterns, bool disableAtomicRMW) {
2003 
2004  // Populate `vector.*` conversion patterns.
2005  // TODO: #119553 support atomicity
2006  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2007  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2008  typeConverter, patterns.getContext());
2009 
2010  // Populate `vector.*` store conversion patterns. The caller can choose
2011  // to avoid emitting atomic operations and reduce them to a read-modify-write
2012  // sequence for stores if it is known that there is no thread contention.
2013  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2014 }
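// Minimal usage sketch (hypothetical pass boilerplate; assumes a suitably
// configured ConversionTarget `target` and an op `op` to convert):
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();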
2015 
2016 void vector::populateVectorNarrowTypeRewritePatterns(
2017  RewritePatternSet &patterns, PatternBenefit benefit) {
2018  // TODO: Document what the emulated type is.
2019  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2020  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2021  benefit);
2022 
2023  // Patterns for aligned cases. We set a higher benefit as they are expected
2024  // to generate better code for the aligned cases.
2025  // The emulated type is always i8.
2026  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2027  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2028  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2029  benefit.getBenefit() + 1);
2030  // The emulated type is always i8.
2031  patterns
2032  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2033  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2034  patterns.getContext(), benefit.getBenefit() + 1);
2035 }
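// Minimal usage sketch (hypothetical; `ctx` and `op` are assumed to exist):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   // Apply with a greedy driver, e.g. applyPatternsGreedily(op,
//   // std::move(patterns)), to rewrite matching narrow-type ops in place.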
2036 
2037 // The emulated type is always i8.
2038 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2039  RewritePatternSet &patterns, PatternBenefit benefit) {
2040  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2041 }