VectorEmulateNarrowType.cpp
1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4
11 // ("emulated type"), using wider types, e.g. i8 ("container type").
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are 8 bits wide or wider.
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/DebugLog.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
41 using namespace mlir;
42 
43 #define DEBUG_TYPE "vector-narrow-type-emulation"
44 
45 using VectorValue = TypedValue<VectorType>;
46 using MemRefValue = TypedValue<BaseMemRefType>;
47 
48 //===----------------------------------------------------------------------===//
49 // Utils
50 //===----------------------------------------------------------------------===//
51 
52 /// Returns a compressed mask for the emulated vector. For example, when
53 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
54 /// elements span two dest elements), this method compresses `vector<8xi1>`
55 /// into `vector<2xi1>`.
56 ///
57 /// The compressed/output mask value is set iff any mask bit in the
58 /// corresponding `numSrcElemsPerDest` range of uncompressed/input mask bits is
59 /// set. E.g., if `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1,
60 /// the following mask:
61 ///
62 /// %mask = [1, 1, 0, 0, 0, 0]
63 ///
64 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
65 /// will be added in the back to make the number of elements a multiple of
66 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
67 ///
68 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
69 ///
70 /// then it will return the following new compressed mask:
71 ///
72 /// %mask = [1, 1, 0, 0]
73 ///
74 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
75 /// `numSrcElemsPerDest`.
76 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
77  Location loc, Value mask,
78  int numSrcElems,
79  int numSrcElemsPerDest,
80  int numFrontPadElems = 0) {
81 
82  assert(numFrontPadElems < numSrcElemsPerDest &&
83  "numFrontPadElems must be less than numSrcElemsPerDest");
84 
85  auto numDestElems =
86  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
87  numSrcElemsPerDest;
88 
89  Operation *maskOp = mask.getDefiningOp();
90  SmallVector<vector::ExtractOp, 2> extractOps;
91  // TODO: Add support for `vector.splat`.
92  // Finding the mask creation operation.
93  while (maskOp &&
94  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
95  maskOp)) {
96  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
97  maskOp = extractOp.getVector().getDefiningOp();
98  extractOps.push_back(extractOp);
99  }
100  }
101 
102  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
103  maskOp))
104  return failure();
105 
106  // Computing the "compressed" mask. All the emulation logic (i.e. computing
107  // new mask index) only happens on the last dimension of the vectors.
108  SmallVector<int64_t> maskShape(
109  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
110  maskShape.back() = numDestElems;
111  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
112  std::optional<Operation *> newMask =
113      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
114  .Case<vector::CreateMaskOp>(
115  [&](auto createMaskOp) -> std::optional<Operation *> {
116  OperandRange maskOperands = createMaskOp.getOperands();
117  // The `vector.create_mask` op creates a mask arrangement
118  // without any zeros at the front. Also, because
119  // `numFrontPadElems` is strictly smaller than
120  // `numSrcElemsPerDest`, the compressed mask generated by
121  // padding the original mask by `numFrontPadElems` will not
122  // have any zeros at the front either.
123  AffineExpr s0;
124  bindSymbols(rewriter.getContext(), s0);
125  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
126  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
127  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
128      rewriter, loc, s0, origIndex);
129  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
130  newMaskOperands.push_back(
131  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
132  return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
133  newMaskOperands);
134  })
135  .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
136  -> std::optional<Operation *> {
137  // Take the shape of mask, compress its trailing dimension:
138  SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
139  int64_t &maskIndex = maskDimSizes.back();
140  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
141  numSrcElemsPerDest);
142  return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
143  maskDimSizes);
144  })
145  .Case<arith::ConstantOp>([&](auto constantOp)
146  -> std::optional<Operation *> {
147  // TODO: Support multiple dimensions.
148  if (maskShape.size() != 1)
149  return std::nullopt;
150  // Rearrange the original mask values to cover the whole potential
151  // loading region. For example, in the case of using byte-size for
152  // emulation, given the following mask:
153  //
154  // %mask = [0, 1, 0, 1, 0, 0]
155  //
156  // With front offset of 1, the mask will be padded 0s in the front
157  // and back so that:
158  // 1. It is aligned with the effective loading bits
159  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
160  // coverage size is a multiple of bytes). The new mask will be like
161  // this before compressing:
162  //
163  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
164  auto originalMask =
165  cast<DenseIntElementsAttr>(constantOp.getValue());
166  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
167  paddedMaskValues.append(originalMask.template value_begin<bool>(),
168  originalMask.template value_end<bool>());
169  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
170 
171  // Compressing by combining every `numSrcElemsPerDest` elements:
172  SmallVector<bool> compressedMaskValues;
173  for (size_t i = 0; i < paddedMaskValues.size();
174  i += numSrcElemsPerDest) {
175  bool combinedValue = false;
176  for (int j = 0; j < numSrcElemsPerDest; ++j) {
177  combinedValue |= paddedMaskValues[i + j];
178  }
179  compressedMaskValues.push_back(combinedValue);
180  }
181  return arith::ConstantOp::create(
182  rewriter, loc,
183  DenseElementsAttr::get(newMaskType, compressedMaskValues));
184  });
185 
186  if (!newMask)
187  return failure();
188 
189  while (!extractOps.empty()) {
190  newMask =
191  vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
192  extractOps.back().getMixedPosition());
193  extractOps.pop_back();
194  }
195 
196  return *newMask;
197 }
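// Illustrative sketch (not from the original source): a concrete instance of
// the compression above for the `vector.constant_mask` case, assuming
// numSrcElems = 8, numSrcElemsPerDest = 4 and numFrontPadElems = 0:
//
//   %mask     = vector.constant_mask [6] : vector<8xi1>
//   // divideCeil(0 + 6, 4) == 2 and numDestElems == 2, so this becomes:
//   %new_mask = vector.constant_mask [2] : vector<2xi1>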
198 
199 /// Extracts 1-D subvector from a 1-D vector.
200 ///
201 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
202 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
203 ///
204 /// vector<numElemsToExtract x !elemType>
205 ///
206 /// (`!elemType` is the element type of the source vector). As `offset` is a known
207 /// _static_ value, this helper hook emits `vector.extract_strided_slice`.
208 ///
209 /// EXAMPLE:
210 /// %res = vector.extract_strided_slice %src
211 /// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
212 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
213                                         Value src, int64_t offset,
214  int64_t numElemsToExtract) {
215  auto vectorType = cast<VectorType>(src.getType());
216  assert(vectorType.getRank() == 1 && "expected source to be a rank-1 vector");
217  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
218  "subvector out of bounds");
219 
220  // When extracting all available elements, just use the source vector as the
221  // result.
222  if (vectorType.getNumElements() == numElemsToExtract)
223  return src;
224 
225  auto offsets = rewriter.getI64ArrayAttr({offset});
226  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
227  auto strides = rewriter.getI64ArrayAttr({1});
228 
229  auto resultVectorType =
230  VectorType::get({numElemsToExtract}, vectorType.getElementType());
231  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
232  src, offsets, sizes, strides)
233  ->getResult(0);
234 }
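// Illustrative instance (not from the original source): with offset = 2 and
// numElemsToExtract = 4 on a vector<8xi2> source, the helper above emits:
//
//   %res = vector.extract_strided_slice %src
//            {offsets = [2], sizes = [4], strides = [1]}
//            : vector<8xi2> to vector<4xi2>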
235 
236 /// Inserts 1-D subvector into a 1-D vector.
237 ///
238 /// Inserts the input rank-1 source vector into the destination vector starting
239 /// at `offset`. As `offset` is a known _static_ value, this helper hook emits
240 /// `vector.insert_strided_slice`.
241 ///
242 /// EXAMPLE:
243 /// %res = vector.insert_strided_slice %src, %dest
244 /// {offsets = [%offset], strides [1]}
245 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
246                                        Value src, Value dest, int64_t offset) {
247  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
248  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
249  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
250  "expected source and dest to be rank-1 vector types");
251 
252  // If overwriting the whole destination vector, just return the source.
253  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
254  return src;
255 
256  auto offsets = rewriter.getI64ArrayAttr({offset});
257  auto strides = rewriter.getI64ArrayAttr({1});
258  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
259  dest, offsets, strides);
260 }
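// Illustrative instance (not from the original source): inserting a
// vector<4xi2> into a vector<8xi2> destination at offset = 4 emits:
//
//   %res = vector.insert_strided_slice %src, %dest
//            {offsets = [4], strides = [1]}
//            : vector<4xi2> into vector<8xi2>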
261 
262 /// Extracts 1-D subvector from a 1-D vector.
263 ///
264 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
265 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
266 ///
267 /// vector<numElemsToExtract x !elType>
268 ///
269 /// (`!elType` is the element type of the source vector). As `offset` is assumed
270 /// to be a _dynamic_ SSA value, this helper method generates a sequence of
271 /// `vector.extract` + `vector.insert` pairs.
272 ///
273 /// EXAMPLE:
274 /// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
275 /// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
276 /// %c1 = arith.constant 1 : index
277 /// %idx2 = arith.addi %offset, %c1 : index
278 /// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
279 /// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
280 /// (...)
281 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
282                                          Value src, Value dest,
283  OpFoldResult offset,
284  int64_t numElemsToExtract) {
285  auto srcVecTy = cast<VectorType>(src.getType());
286  assert(srcVecTy.getRank() == 1 && "expected source to be a rank-1 vector");
287  // NOTE: We are unable to take the offset into account in the following
288  // assert; hence it is still possible that the subvector is out-of-bounds even
289  // if the condition is true.
290  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
291  "subvector out of bounds");
292 
293  // When extracting all available elements, just use the source vector as the
294  // result.
295  if (srcVecTy.getNumElements() == numElemsToExtract)
296  return src;
297 
298  for (int i = 0; i < numElemsToExtract; ++i) {
299  Value extractLoc =
300  (i == 0) ? dyn_cast<Value>(offset)
301  : arith::AddIOp::create(
302  rewriter, loc, rewriter.getIndexType(),
303  dyn_cast<Value>(offset),
304  arith::ConstantIndexOp::create(rewriter, loc, i));
305  auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
306  dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
307  }
308  return dest;
309 }
310 
311 /// Inserts 1-D subvector into a 1-D vector.
312 ///
313 /// Inserts the input rank-1 source vector into the destination vector starting
314 /// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
315 /// uses a sequence of `vector.extract` + `vector.insert` pairs.
316 ///
317 /// EXAMPLE:
318 /// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
319 /// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
320 /// %c1 = arith.constant 1 : index
321 /// %idx2 = arith.addi %offset, %c1 : index
322 /// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
323 /// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
324 /// (...)
325 static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
326                                         Value src, Value dest,
327  OpFoldResult offset,
328  int64_t numElemsToInsert) {
329  auto srcVecTy = cast<VectorType>(src.getType());
330  auto destVecTy = cast<VectorType>(dest.getType());
331  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
332  "expected source and dest to be rank-1 vector types");
333  (void)srcVecTy;
334  (void)destVecTy;
335  assert(numElemsToInsert > 0 &&
336  "the number of elements to insert must be greater than 0");
337  // NOTE: We are unable to take the offset into account in the following
338  // assert; hence it is still possible that the subvector is out-of-bounds even
339  // if the condition is true.
340  assert(numElemsToInsert <= destVecTy.getNumElements() &&
341  "subvector out of bounds");
342 
343  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
344  for (int64_t i = 0; i < numElemsToInsert; ++i) {
345  auto insertLoc =
346  i == 0 ? destOffsetVal
347  : arith::AddIOp::create(
348  rewriter, loc, rewriter.getIndexType(), destOffsetVal,
349  arith::ConstantIndexOp::create(rewriter, loc, i));
350  auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
351  dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
352  }
353  return dest;
354 }
355 
356 /// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
357 ///
358 /// Specifically, use `containerElemTy` for loading a vector of
359 /// `emulatedElemTy`. The load location is given by `base` and
360 /// `linearizedIndices`, and the load size is given by
361 /// `numContainerElemsToLoad` (expressed in container-type elements).
362 static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
363                                       Value base,
364  OpFoldResult linearizedIndices,
365  int64_t numContainerElemsToLoad,
366  Type emulatedElemTy,
367  Type containerElemTy) {
368  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
369  emulatedElemTy.getIntOrFloatBitWidth();
370  auto newLoad = vector::LoadOp::create(
371  rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
372  base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
373  return vector::BitCastOp::create(
374  rewriter, loc,
375  VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
376  emulatedElemTy),
377  newLoad);
378 }
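// Illustrative sketch (not verbatim output; %base and %lin_idx are placeholder
// names): for i4 emulated with i8 and numContainerElemsToLoad = 4, the helper
// above emits roughly:
//
//   %0 = vector.load %base[%lin_idx] : memref<16xi8>, vector<4xi8>
//   %1 = vector.bitcast %0 : vector<4xi8> to vector<8xi4>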
379 
380 /// Downcasts two values to `downcastType`, selects between them based on
381 /// `mask`, and casts the result to `upcastType`.
382 static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
383                                      VectorType downcastType,
384  VectorType upcastType, Value mask,
385  Value trueValue, Value falseValue) {
386  assert(
387  downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
388  upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
389  "expected input and output number of bits to match");
390  if (trueValue.getType() != downcastType) {
391  trueValue =
392  vector::BitCastOp::create(builder, loc, downcastType, trueValue);
393  }
394  if (falseValue.getType() != downcastType) {
395  falseValue =
396  vector::BitCastOp::create(builder, loc, downcastType, falseValue);
397  }
398  Value selectedType =
399  arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
400  // Upcast the selected value to the new type.
401  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
402 }
403 
404 /// Emits a `memref.generic_atomic_rmw` op to store a sub-byte-sized value into
405 /// a byte of `linearizedMemref`, guarded by a mask. `valueToStore` is a vector
406 /// of sub-byte-sized elements that together span 8 bits (one byte), and the
407 /// mask selects which of those elements are actually stored.
408 ///
409 /// Inputs:
410 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
411 /// storeIdx = 2
412 /// valueToStore = |3|3|3|3| : vector<4xi2>
413 /// mask = |0|0|1|1| : vector<4xi1>
414 ///
415 /// Result:
416 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
417 static void atomicRMW(OpBuilder &builder, Location loc,
418  MemRefValue linearizedMemref, Value storeIdx,
419  VectorValue valueToStore, Value mask) {
420  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
421 
422  // Create an atomic load-modify-write region using
423  // `memref.generic_atomic_rmw`.
424  auto atomicOp = memref::GenericAtomicRMWOp::create(
425  builder, loc, linearizedMemref, ValueRange{storeIdx});
426  Value origValue = atomicOp.getCurrentValue();
427 
428  OpBuilder::InsertionGuard guard(builder);
429  builder.setInsertionPointToStart(atomicOp.getBody());
430 
431  // Load the original value from memory, and cast it to the original element
432  // type.
433  auto oneElemVecType = VectorType::get({1}, origValue.getType());
434  Value origVecValue = vector::FromElementsOp::create(
435  builder, loc, oneElemVecType, ValueRange{origValue});
436 
437  // Construct the final masked value and yield it.
438  Value maskedValue =
439  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
440  oneElemVecType, mask, valueToStore, origVecValue);
441  auto scalarMaskedValue =
442  vector::ExtractOp::create(builder, loc, maskedValue, 0);
443  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
444 }
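// Illustrative sketch (assuming i2 elements packed into i8; value names are
// placeholders, not verbatim compiler output): the region built by `atomicRMW`
// above looks roughly like:
//
//   memref.generic_atomic_rmw %linearized_memref[%store_idx] : memref<3xi8> {
//   ^bb0(%current : i8):
//     %orig      = vector.from_elements %current : vector<1xi8>
//     %orig_bits = vector.bitcast %orig : vector<1xi8> to vector<4xi2>
//     %sel       = arith.select %mask, %value_to_store, %orig_bits
//                    : vector<4xi1>, vector<4xi2>
//     %packed    = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
//     %scalar    = vector.extract %packed[0] : i8 from vector<1xi8>
//     memref.atomic_yield %scalar : i8
//   }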
445 
446 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
447 /// type. It has similar logic to `atomicRMW` above, but without atomicity.
448 static void nonAtomicRMW(OpBuilder &builder, Location loc,
449  MemRefValue linearizedMemref, Value linearizedIndex,
450  VectorValue valueToStore, Value mask) {
451  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
452 
453  auto oneElemVecType =
454  VectorType::get({1}, linearizedMemref.getType().getElementType());
455  Value origVecValue =
456  vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
457  ValueRange{linearizedIndex});
458  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
459  origVecValue);
460 
461  Value maskedValue =
462  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
463  oneElemVecType, mask, valueToStore, origVecValue);
464  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
465  linearizedIndex);
466 }
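// Illustrative sketch (i2-in-i8 case; placeholder names): the non-atomic
// variant above emits a plain load/select/store sequence instead of a
// `memref.generic_atomic_rmw` region:
//
//   %old      = vector.load %mem[%idx] : memref<3xi8>, vector<1xi8>
//   %old_bits = vector.bitcast %old : vector<1xi8> to vector<4xi2>
//   %sel      = arith.select %mask, %value_to_store, %old_bits
//                 : vector<4xi1>, vector<4xi2>
//   %packed   = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
//   vector.store %packed, %mem[%idx] : memref<3xi8>, vector<1xi8>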
467 
468 /// Extract `sliceNumElements` elements from source `vector` at `extractOffset`,
469 /// and insert it into an empty vector at `insertOffset`.
470 /// Inputs:
471 /// vec_in = |0|1|2|3| : vector<4xi2>
472 /// extractOffset = 1
473 /// sliceNumElements = 2
474 /// insertOffset = 2
475 /// Output:
476 /// vec_out = |0|0|1|2| : vector<4xi2>
477 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
478                                   Location loc, VectorValue vector,
479  int64_t extractOffset,
480  int64_t sliceNumElements,
481  int64_t insertOffset) {
482  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
483  auto vectorElementType = vector.getType().getElementType();
484  // TODO: update and use `alignedConversionPrecondition` in the place of
485  // these asserts.
486  assert(
487  sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
488  "sliceNumElements * vector element size must be less than or equal to 8");
489  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
490  "vector element must be a valid sub-byte type");
491  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
492  auto emptyByteVector = arith::ConstantOp::create(
493  rewriter, loc,
494  VectorType::get({emulatedPerContainerElem}, vectorElementType),
495  rewriter.getZeroAttr(
496  VectorType::get({emulatedPerContainerElem}, vectorElementType)));
497  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
498  extractOffset, sliceNumElements);
499  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
500  insertOffset);
501 }
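// Illustrative sketch (not verbatim output): for the i2 example in the comment
// above (extractOffset = 1, sliceNumElements = 2, insertOffset = 2), this
// helper produces roughly:
//
//   %zero  = arith.constant dense<0> : vector<4xi2>
//   %slice = vector.extract_strided_slice %vec_in
//              {offsets = [1], sizes = [2], strides = [1]}
//              : vector<4xi2> to vector<2xi2>
//   %out   = vector.insert_strided_slice %slice, %zero
//              {offsets = [2], strides = [1]}
//              : vector<2xi2> into vector<4xi2>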
502 
503 namespace {
504 
505 //===----------------------------------------------------------------------===//
506 // ConvertVectorStore
507 //===----------------------------------------------------------------------===//
508 
509 // Emulate `vector.store` using a multi-byte container type.
510 //
511 // The container type is obtained through Op adaptor and would normally be
512 // generated via `NarrowTypeEmulationConverter`.
513 //
514 // EXAMPLE 1
515 // (aligned store of i4, emulated using i8 as the container type)
516 //
517 // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
518 //
519 // is rewritten as:
520 //
521 // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
522 // vector.store %src_bitcast, %dest_bitcast[%idx]
523 // : memref<16xi8>, vector<4xi8>
524 //
525 // EXAMPLE 2
526 // (unaligned store of i2, emulated using i8 as the container type)
527 //
528 // vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
529 //
530 // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
531 // is modelled using 3 bytes:
532 //
533 // Byte 0 Byte 1 Byte 2
534 // +----------+----------+----------+
535 // | oooooooo | ooooNNNN | NNoooooo |
536 // +----------+----------+----------+
537 //
538 // N - (N)ew entries (i.e. to be overwritten by vector.store)
539 // o - (o)ld entries (to be preserved)
540 //
541 // For the generated output in the non-atomic case, see:
542 //   * `@vector_store_i2_const_index_two_partial_stores`
543 // in:
544 // * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
545 //
546 // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
547 // `true` to generate non-atomic RMW sequences.
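//
// Sketch for EXAMPLE 2 (illustrative, assuming the non-atomic path; %dest_i8
// is a placeholder for the linearized i8 view of the destination; see the
// test referenced above for the exact output). Two partial RMW stores are
// emitted:
//
//   // Byte 1: 2 new i2 elements, mask = [0, 0, 1, 1]
//   %old1 = vector.load %dest_i8[%c1] : memref<3xi8>, vector<1xi8>
//   ... bitcast / arith.select / bitcast / vector.store ...
//   // Byte 2: 1 new i2 element, mask = [1, 0, 0, 0]
//   %old2 = vector.load %dest_i8[%c2] : memref<3xi8>, vector<1xi8>
//   ... bitcast / arith.select / bitcast / vector.store ...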
548 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
549  using OpConversionPattern::OpConversionPattern;
550 
551  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
552  : OpConversionPattern<vector::StoreOp>(context),
553  disableAtomicRMW(disableAtomicRMW) {}
554 
555  LogicalResult
556  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
557  ConversionPatternRewriter &rewriter) const override {
558 
559  // See #115653
560  if (op.getValueToStore().getType().getRank() != 1)
561  return rewriter.notifyMatchFailure(op,
562  "only 1-D vectors are supported ATM");
563 
564  auto loc = op.getLoc();
565 
566  auto valueToStore = cast<VectorValue>(op.getValueToStore());
567  auto containerElemTy =
568  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
569  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
570  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
571  int containerBits = containerElemTy.getIntOrFloatBitWidth();
572 
573  // Check per-element alignment.
574  if (containerBits % emulatedBits != 0) {
575  return rewriter.notifyMatchFailure(
576  op, "impossible to pack emulated elements into container elements "
577  "(bit-wise misalignment)");
578  }
579  int emulatedPerContainerElem = containerBits / emulatedBits;
580 
581  // Adjust the number of elements to store when emulating narrow types.
582  // Here only the 1-D vector store is considered, and the N-D memref types
583  // should be linearized.
584  // For example, to emulate i4 to i8, the following op:
585  //
586  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
587  //
588  // can be replaced with
589  //
590  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
591  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
592  // vector<4xi8>
593 
594  auto origElements = valueToStore.getType().getNumElements();
595  // Note, per-element-alignment was already verified above.
596  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
597  // Does the trailing dim of the source and destination match? If yes, then the
598  // corresponding index must be 0.
599  // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
600  // However, that makes some tests fail, so we need to audit first.
601  auto trailingDim = op.getBase().getType().getShape().back();
602  bool trailingDimsMatch =
603  ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
604 
605  auto stridedMetadata =
606  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
607 
608  // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
609  // non-zero. As such, this is not needed.
610  OpFoldResult linearizedIndices;
611  memref::LinearizedMemRefInfo linearizedInfo;
612  std::tie(linearizedInfo, linearizedIndices) =
613      memref::getLinearizedMemRefOffsetAndSize(
614  rewriter, loc, emulatedBits, containerBits,
615  stridedMetadata.getConstifiedMixedOffset(),
616  stridedMetadata.getConstifiedMixedSizes(),
617  stridedMetadata.getConstifiedMixedStrides(),
618  getAsOpFoldResult(adaptor.getIndices()));
619 
620  std::optional<int64_t> foldedNumFrontPadElems =
621  (isDivisibleInSize && trailingDimsMatch)
622  ? 0
623  : getConstantIntValue(linearizedInfo.intraDataOffset);
624 
625  if (!foldedNumFrontPadElems) {
626  return rewriter.notifyMatchFailure(
627  op, "subbyte store emulation: dynamic front padding size is "
628  "not yet implemented");
629  }
630 
631  auto memrefBase = cast<MemRefValue>(adaptor.getBase());
632 
633  // RMWs are not needed when:
634  // * no _partial_ stores are required.
635  // A partial store is defined as a store in which only a part of the
636  // container element is overwritten, e.g.
637  //
638  // Dest before (8 bits)
639  // +----------+
640  // | 11000000 |
641  // +----------+
642  //
643  // Dest after storing 0xF at offset 4 (in bits)
644  // +----------+
645  // | 11001111 |
646  // +----------+
647  //
648  // At a higher level, this translates to:
649  // 1. The source vector size (in bits) is a multiple of byte size.
650  // 2. The address of the store is aligned to the container type width
651  // boundary.
652  //
653  // EXAMPLE 1:
654  // Requires partial store:
655  // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
656  //
657  // EXAMPLE 2:
658  // Does not require a partial store:
659  // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
660  //
661  // TODO: Take linearizedInfo.linearizedOffset into account. This is
662  // currently not needed/used/exercised as all our tests set offset to 0.
663  bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
664 
665  if (!emulationRequiresPartialStores) {
666  // Basic case: storing full bytes.
667  auto numElements = origElements / emulatedPerContainerElem;
668  auto bitCast = vector::BitCastOp::create(
669  rewriter, loc, VectorType::get(numElements, containerElemTy),
670  op.getValueToStore());
671  rewriter.replaceOpWithNewOp<vector::StoreOp>(
672  op, bitCast.getResult(), memrefBase,
673  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
674  return success();
675  }
676 
677  // Next, handle the case when sub-byte read-modify-write
678  // sequences are needed to emulate a vector store.
679  // Here is an example:
680  //
681  // Vector to store: vector<7xi2>
682  // Value to store: 11 11 11 11 11 11 11 (all ones)
683  //
684  // Destination: memref<12xi2>
685  // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
686  //
687  // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
688  //
689  // Destination memref before:
690  //
691  // Byte 0 Byte 1 Byte 2
692  // +----------+----------+----------+
693  // | 00000000 | 00000000 | 00000000 |
694  // +----------+----------+----------+
695  //
696  // Destination memref after:
697  //
698  // Byte 0 Byte 1 Byte 2
699  // +----------+----------+----------+
700  // | 00001111 | 11111111 | 11000000 |
701  // +----------+----------+----------+
702  //
703  // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
704  // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
705  // requiring RMW access (atomicity is required).
706 
707  // The index into the target memref we are storing to.
708  Value currentDestIndex =
709  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
710  // The index into the source vector we are currently processing.
711  auto currentSourceIndex = 0;
712 
713  // Build a mask used for rmw.
714  auto subWidthStoreMaskType =
715  VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
716 
717  auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
718 
719  // 1. Partial width store for the leading byte.
720  // When the store address is not aligned to the emulated width boundary, deal
721  // with the unaligned part first so that the remaining elements are aligned to
722  // the width boundary.
723  auto frontSubWidthStoreElem =
724  (emulatedPerContainerElem - *foldedNumFrontPadElems) %
725  emulatedPerContainerElem;
726  if (frontSubWidthStoreElem > 0) {
727  SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
728  if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
729  std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
730  origElements, true);
731  frontSubWidthStoreElem = origElements;
732  } else {
733  std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
734  *foldedNumFrontPadElems, true);
735  }
736  auto frontMask = arith::ConstantOp::create(
737  rewriter, loc,
738  DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
739 
740  currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
741  auto value =
742  extractSliceIntoByte(rewriter, loc, valueToStore, 0,
743  frontSubWidthStoreElem, *foldedNumFrontPadElems);
744 
745  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
746  cast<VectorValue>(value), frontMask.getResult());
747  }
748 
749  if (currentSourceIndex >= origElements) {
750  rewriter.eraseOp(op);
751  return success();
752  }
753 
754  // Increment the destination index by 1 to align to the emulated width
755  // boundary.
756  auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
757  currentDestIndex = arith::AddIOp::create(
758  rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
759 
760  // 2. Full width store for the inner output bytes.
761  // After the previous step, the store address is aligned to the emulated
762  // width boundary.
763  int64_t fullWidthStoreSize =
764  (origElements - currentSourceIndex) / emulatedPerContainerElem;
765  int64_t numNonFullWidthElements =
766  fullWidthStoreSize * emulatedPerContainerElem;
767  if (fullWidthStoreSize > 0) {
768  auto fullWidthStorePart = staticallyExtractSubvector(
769  rewriter, loc, valueToStore, currentSourceIndex,
770  numNonFullWidthElements);
771 
772  auto originType = cast<VectorType>(fullWidthStorePart.getType());
773  auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
774  auto storeType = VectorType::get(
775  {originType.getNumElements() / emulatedPerContainerElem},
776  memrefElemType);
777  auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
778  fullWidthStorePart);
779  vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
780  currentDestIndex);
781 
782  currentSourceIndex += numNonFullWidthElements;
783  currentDestIndex = arith::AddIOp::create(
784  rewriter, loc, rewriter.getIndexType(), currentDestIndex,
785  arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
786  }
787 
788  // 3. Partial width store for the trailing output byte.
789  // It is needed when the residual length is smaller than the emulated width,
790  // which is not covered in step 2 above.
791  auto remainingElements = origElements - currentSourceIndex;
792  if (remainingElements != 0) {
793  auto subWidthStorePart =
794  extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
795  currentSourceIndex, remainingElements, 0);
796 
797  // Generate back mask.
798  auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
799  std::fill_n(maskValues.begin(), remainingElements, 1);
800  auto backMask = arith::ConstantOp::create(
801  rewriter, loc,
802  DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
803 
804  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
805  cast<VectorValue>(subWidthStorePart), backMask.getResult());
806  }
807 
808  rewriter.eraseOp(op);
809  return success();
810  }
811 
812 private:
813  const bool disableAtomicRMW;
814 };
815 
816 //===----------------------------------------------------------------------===//
817 // ConvertVectorMaskedStore
818 //===----------------------------------------------------------------------===//
819 
820 // TODO: Document-me
821 struct ConvertVectorMaskedStore final
822  : OpConversionPattern<vector::MaskedStoreOp> {
823  using OpConversionPattern::OpConversionPattern;
824 
825  LogicalResult
826  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
827  ConversionPatternRewriter &rewriter) const override {
828 
829  // See #115653
830  if (op.getValueToStore().getType().getRank() != 1)
831  return rewriter.notifyMatchFailure(op,
832  "only 1-D vectors are supported ATM");
833 
834  auto loc = op.getLoc();
835  auto containerElemTy =
836  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
837  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
838  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
839  int containerBits = containerElemTy.getIntOrFloatBitWidth();
840 
841  // Check per-element alignment.
842  if (containerBits % emulatedBits != 0) {
843  return rewriter.notifyMatchFailure(
844  op, "impossible to pack emulated elements into container elements "
845  "(bit-wise misalignment)");
846  }
847 
848  int emulatedPerContainerElem = containerBits / emulatedBits;
849  int origElements = op.getValueToStore().getType().getNumElements();
850  if (origElements % emulatedPerContainerElem != 0)
851  return failure();
852 
853  auto stridedMetadata =
854  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
855  OpFoldResult linearizedIndicesOfr;
856  memref::LinearizedMemRefInfo linearizedInfo;
857  std::tie(linearizedInfo, linearizedIndicesOfr) =
858      memref::getLinearizedMemRefOffsetAndSize(
859  rewriter, loc, emulatedBits, containerBits,
860  stridedMetadata.getConstifiedMixedOffset(),
861  stridedMetadata.getConstifiedMixedSizes(),
862  stridedMetadata.getConstifiedMixedStrides(),
863  getAsOpFoldResult(adaptor.getIndices()));
864  Value linearizedIndices =
865  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
866 
867  // Load the whole data and use arith.select to handle the corner cases.
868  //
869  // As an example, for this masked store of i4 values:
870  //
871  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
872  //
873  // and given these input values:
874  //
875  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
876  // %0[%c0, %c0] =
877  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
878  // %val_to_store =
879  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
880  //
881  // we'll have the following i4 output:
882  //
883  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
884  //
885  // Emulating the above using i8 will give:
886  //
887  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
888  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
889  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
890  // %select_using_shifted_mask =
891  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
892  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
893  //
894  // Using the compressed mask to store %packed_data results in expected
895  // output.
896  //
897  // FIXME: Make an example based on the comment above work (see #115460 for
898  // reproducer).
899  FailureOr<Operation *> newMask = getCompressedMaskOp(
900  rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
901  if (failed(newMask))
902  return failure();
903 
904  auto numElements = (origElements + emulatedPerContainerElem - 1) /
905  emulatedPerContainerElem;
906  auto newType = VectorType::get(numElements, containerElemTy);
907  auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
908  rewriter.getZeroAttr(newType));
909 
910  auto newLoad = vector::MaskedLoadOp::create(
911  rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
912  newMask.value()->getResult(0), passThru);
913 
914  auto newBitCastType =
915  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
916  Value valueToStore =
917  vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
918  valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
919  op.getValueToStore(), valueToStore);
920  valueToStore =
921  vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
922 
923  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
924  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
925  valueToStore);
926  return success();
927  }
928 };
929 
930 //===----------------------------------------------------------------------===//
931 // ConvertVectorLoad
932 //===----------------------------------------------------------------------===//
933 
934 // TODO: Document-me
935 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
936  using OpConversionPattern::OpConversionPattern;
937 
938  LogicalResult
939  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
940  ConversionPatternRewriter &rewriter) const override {
941 
942  // See #115653
943  if (op.getVectorType().getRank() != 1)
944  return rewriter.notifyMatchFailure(op,
945  "only 1-D vectors are supported ATM");
946 
947  auto loc = op.getLoc();
948  auto containerElemTy =
949  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
950  Type emulatedElemTy = op.getType().getElementType();
951  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
952  int containerBits = containerElemTy.getIntOrFloatBitWidth();
953 
954  // Check per-element alignment.
955  if (containerBits % emulatedBits != 0) {
956  return rewriter.notifyMatchFailure(
957  op, "impossible to pack emulated elements into container elements "
958  "(bit-wise misalignment)");
959  }
960  int emulatedPerContainerElem = containerBits / emulatedBits;
961 
962  // Adjust the number of elements to load when emulating narrow types,
963  // and then cast back to the original type with vector.bitcast op.
964  // Here only the 1-D vector load is considered, and the N-D memref types
965  // should be linearized.
966  // For example, to emulate i4 to i8, the following op:
967  //
968  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
969  //
970  // can be replaced with
971  //
972  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
973  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
974  //
975  // There are cases where the number of elements to load is not byte-aligned,
976  // for example:
977  //
978  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
979  //
980  // we will have to load extra bytes and extract the exact slice in between.
981  //
982  //   %1 = vector.load %0[%c0] : memref<3xi8>, vector<2xi8>
983  //   %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
984  //   %3 = vector.extract_strided_slice %2 {offsets = [3], sizes = [3],
985  //        strides = [1]}
986  //        : vector<8xi2> to vector<3xi2>
987  //
988  // TODO: Currently the extract_strided_slice's attributes must be known at
989  // compile time as they must be constants.
990 
991  auto origElements = op.getVectorType().getNumElements();
992  // Note, per-element-alignment was already verified above.
993  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
994 
995  auto stridedMetadata =
996  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
997 
998  OpFoldResult linearizedIndices;
999  memref::LinearizedMemRefInfo linearizedInfo;
1000  std::tie(linearizedInfo, linearizedIndices) =
1001      memref::getLinearizedMemRefOffsetAndSize(
1002  rewriter, loc, emulatedBits, containerBits,
1003  stridedMetadata.getConstifiedMixedOffset(),
1004  stridedMetadata.getConstifiedMixedSizes(),
1005  stridedMetadata.getConstifiedMixedStrides(),
1006  getAsOpFoldResult(adaptor.getIndices()));
1007 
1008  std::optional<int64_t> foldedIntraVectorOffset =
1009  isDivisibleInSize ? 0
1010  : getConstantIntValue(linearizedInfo.intraDataOffset);
1011 
1012  // Always load enough elements to cover the original elements.
1013  int64_t maxIntraDataOffset =
1014      foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1015  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1016                                      emulatedPerContainerElem);
1017  Value result =
1018  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1019  numElements, emulatedElemTy, containerElemTy);
1020 
1021  if (!foldedIntraVectorOffset) {
1022  auto resultVector = arith::ConstantOp::create(
1023  rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1024  result = dynamicallyExtractSubVector(
1025  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1026  linearizedInfo.intraDataOffset, origElements);
1027  } else if (!isDivisibleInSize) {
1028  result = staticallyExtractSubvector(
1029  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1030  }
1031  rewriter.replaceOp(op, result);
1032  return success();
1033  }
1034 };
1035 
1036 //===----------------------------------------------------------------------===//
1037 // ConvertVectorMaskedLoad
1038 //===----------------------------------------------------------------------===//
1039 
1040 // TODO: Document-me
1041 struct ConvertVectorMaskedLoad final
1042  : OpConversionPattern<vector::MaskedLoadOp> {
1043  using OpConversionPattern::OpConversionPattern;
1044 
1045  LogicalResult
1046  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1047  ConversionPatternRewriter &rewriter) const override {
1048  // See #115653
1049  if (op.getVectorType().getRank() != 1)
1050  return rewriter.notifyMatchFailure(op,
1051  "only 1-D vectors are supported ATM");
1052 
1053  auto loc = op.getLoc();
1054 
1055  auto containerElemTy =
1056  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1057  Type emulatedElemTy = op.getType().getElementType();
1058  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1059  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1060 
1061  // Check per-element alignment.
1062  if (containerBits % emulatedBits != 0) {
1063  return rewriter.notifyMatchFailure(
1064  op, "impossible to pack emulated elements into container elements "
1065  "(bit-wise misalignment)");
1066  }
1067  int emulatedPerContainerElem = containerBits / emulatedBits;
1068 
1069  // Adjust the number of elements to load when emulating narrow types,
1070  // and then cast back to the original type with vector.bitcast op.
1071  // For example, to emulate i4 to i8, the following op:
1072  //
1073  // %mask = vector.constant_mask [3] : vector<6xi1>
1074  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1075  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1076  //
1077  // can be replaced with
1078  //
1079  // %new_mask = vector.constant_mask [2] : vector<3xi1>
1080  // %new_pass_thru = vector.bitcast %pass_thru :
1081  // vector<6xi4> to vector<3xi8>
1082  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1083  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1084  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1085  //
1086  // Since we are effectively loading 16 bits (2xi8) from the memref with the
1087  // new mask, while originally we only wanted to effectively load 12 bits
1088  // (3xi4) from the memref, we need to set the second half of the last i8
1089  // that was effectively loaded (i.e. the second i8) to %pass_thru.
1090  //
1091  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1092  //
1093  // Given these input values:
1094  // %mask = [1, 1, 1, 0, 0, 0]
1095  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1096  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1097  //
1098  // we'll have:
1099  //
1100  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1101  //
1102  // %new_mask = [1, 1, 0]
1103  // %new_pass_thru = [0x78, 0x9A, 0xBC]
1104  // %1 = [0x12, 0x34, 0xBC]
1105  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1106  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1107  //
1108  // TODO: Currently, only the even number of elements loading is supported.
1109  // To deal with the odd number of elements, one has to extract the
1110  // subvector at the proper offset after bit-casting.
1111  auto origType = op.getVectorType();
1112  auto origElements = origType.getNumElements();
1113  // Note, per-element-alignment was already verified above.
1114  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1115 
1116  auto stridedMetadata =
1117  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1118  OpFoldResult linearizedIndices;
1119  memref::LinearizedMemRefInfo linearizedInfo;
1120  std::tie(linearizedInfo, linearizedIndices) =
1121      memref::getLinearizedMemRefOffsetAndSize(
1122  rewriter, loc, emulatedBits, containerBits,
1123  stridedMetadata.getConstifiedMixedOffset(),
1124  stridedMetadata.getConstifiedMixedSizes(),
1125  stridedMetadata.getConstifiedMixedStrides(),
1126  getAsOpFoldResult(adaptor.getIndices()));
1127 
1128  std::optional<int64_t> foldedIntraVectorOffset =
1129  isDivisibleInSize ? 0
1130  : getConstantIntValue(linearizedInfo.intraDataOffset);
1131 
1132  int64_t maxIntraDataOffset =
1133  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1134  FailureOr<Operation *> newMask =
1135  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1136  emulatedPerContainerElem, maxIntraDataOffset);
1137  if (failed(newMask))
1138  return failure();
1139 
1140  Value passthru = op.getPassThru();
1141 
1142  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1143  emulatedPerContainerElem);
1144  auto loadType = VectorType::get(numElements, containerElemTy);
1145  auto newBitcastType =
1146  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1147 
1148  auto emptyVector = arith::ConstantOp::create(
1149  rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1150  if (!foldedIntraVectorOffset) {
1151  passthru = dynamicallyInsertSubVector(
1152  rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1153  origElements);
1154  } else if (!isDivisibleInSize) {
1155  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1156  *foldedIntraVectorOffset);
1157  }
1158  auto newPassThru =
1159  vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1160 
1161  // Generating the new masked load.
1162  auto newLoad = vector::MaskedLoadOp::create(
1163  rewriter, loc, loadType, adaptor.getBase(),
1164  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1165  newMask.value()->getResult(0), newPassThru);
1166 
1167  // Setting the part that originally was not effectively loaded from memory
1168  // to pass through.
1169  auto bitCast =
1170  vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1171 
1172  Value mask = op.getMask();
1173  auto newSelectMaskType = VectorType::get(
1174  numElements * emulatedPerContainerElem, rewriter.getI1Type());
1175  // TODO: try to fold if op's mask is constant
1176  auto emptyMask =
1177  arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1178  rewriter.getZeroAttr(newSelectMaskType));
1179  if (!foldedIntraVectorOffset) {
1180  mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1181  linearizedInfo.intraDataOffset,
1182  origElements);
1183  } else if (!isDivisibleInSize) {
1184  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1185  *foldedIntraVectorOffset);
1186  }
1187 
1188  Value result =
1189  arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1190  if (!foldedIntraVectorOffset) {
1191  result = dynamicallyExtractSubVector(
1192  rewriter, loc, result, op.getPassThru(),
1193  linearizedInfo.intraDataOffset, origElements);
1194  } else if (!isDivisibleInSize) {
1195  result = staticallyExtractSubvector(
1196  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1197  }
1198  rewriter.replaceOp(op, result);
1199 
1200  return success();
1201  }
1202 };
1203 
1204 /// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1205 ///
1206 /// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1207 /// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1208 /// (a multi-byte scalar, e.g. i16), where N is some integer.
1209 ///
1210 /// Put differently, this method checks whether this would be valid:
1211 ///
1212 /// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1213 ///
1214 /// EXAMPLES:
1215 /// * vector<4xi4> -> i16 - yes (N = 1)
1216 /// * vector<4xi4> -> i8 - yes (N = 2)
1217 /// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1218 /// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1219 static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1220  Type multiByteScalarTy) {
1221  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1222 
1223  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1224  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1225 
1226  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1227  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1228  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1229 
1230  int elemsPerMultiByte = multiByteBits / subByteBits;
1231 
1232  // TODO: This is a bit too restrictive for vectors rank > 1.
1233  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // ConvertVectorTransferRead
1238 //===----------------------------------------------------------------------===//
1239 
1240 // TODO: Document-me
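// Sketch of the intended rewrite (illustrative, aligned i4-in-i8 case;
// %src_i8, %lin and %pad_i8 are placeholder names):
//
//   %v = vector.transfer_read %src[%idx], %pad : memref<?xi4>, vector<8xi4>
//
// becomes roughly:
//
//   %pad_i8 = arith.extui %pad : i4 to i8
//   %r = vector.transfer_read %src_i8[%lin], %pad_i8
//          : memref<?xi8>, vector<4xi8>
//   %v = vector.bitcast %r : vector<4xi8> to vector<8xi4>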
1241 struct ConvertVectorTransferRead final
1242  : OpConversionPattern<vector::TransferReadOp> {
1243  using OpConversionPattern::OpConversionPattern;
1244 
1245  LogicalResult
1246  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1247  ConversionPatternRewriter &rewriter) const override {
1248 
1249  // See #115653
1250  if (op.getVectorType().getRank() != 1)
1251  return rewriter.notifyMatchFailure(op,
1252  "only 1-D vectors are supported ATM");
1253 
1254  auto loc = op.getLoc();
1255  auto containerElemTy =
1256  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1257  Type emulatedElemTy = op.getType().getElementType();
1258  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1259  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1260 
1261  // Check per-element alignment.
1262  if (containerBits % emulatedBits != 0) {
1263  return rewriter.notifyMatchFailure(
1264  op, "impossible to pack emulated elements into container elements "
1265  "(bit-wise misalignment)");
1266  }
1267  int emulatedPerContainerElem = containerBits / emulatedBits;
1268 
1269  auto origElements = op.getVectorType().getNumElements();
1270 
1271  // Note, per-element-alignment was already verified above.
1272  bool isDivisibleInSize =
1273  fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1274 
1275  // Pad the padding value with 0s on the left. These bits are discarded and
1276  // thus their values don't matter.
1277  Value padding = adaptor.getPadding();
1278  if (!padding.getType().isInteger()) {
1279  padding = arith::BitcastOp::create(
1280  rewriter, loc,
1281  IntegerType::get(rewriter.getContext(),
1282  padding.getType().getIntOrFloatBitWidth()),
1283  padding);
1284  }
1285  auto newPadding =
1286  arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1287 
1288  auto stridedMetadata =
1289  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1290 
1291  OpFoldResult linearizedIndices;
1292  memref::LinearizedMemRefInfo linearizedInfo;
1293  std::tie(linearizedInfo, linearizedIndices) =
1294      memref::getLinearizedMemRefOffsetAndSize(
1295  rewriter, loc, emulatedBits, containerBits,
1296  stridedMetadata.getConstifiedMixedOffset(),
1297  stridedMetadata.getConstifiedMixedSizes(),
1298  stridedMetadata.getConstifiedMixedStrides(),
1299  getAsOpFoldResult(adaptor.getIndices()));
1300 
1301  std::optional<int64_t> foldedIntraVectorOffset =
1302  isDivisibleInSize ? 0
1303  : getConstantIntValue(linearizedInfo.intraDataOffset);
1304 
1305  int64_t maxIntraDataOffset =
1306  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1307  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1308  emulatedPerContainerElem);
1309 
1310  auto newRead = vector::TransferReadOp::create(
1311  rewriter, loc, VectorType::get(numElements, containerElemTy),
1312  adaptor.getBase(),
1313  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1314  newPadding);
1315 
1316  auto bitCast = vector::BitCastOp::create(
1317  rewriter, loc,
1318  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1319  newRead);
1320 
1321  Value result = bitCast->getResult(0);
1322  if (!foldedIntraVectorOffset) {
1323  auto zeros = arith::ConstantOp::create(
1324  rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1325  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1326  linearizedInfo.intraDataOffset,
1327  origElements);
1328  } else if (!isDivisibleInSize) {
1329  result = staticallyExtractSubvector(
1330  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1331  }
1332  rewriter.replaceOp(op, result);
1333 
1334  return success();
1335  }
1336 };
1337 } // end anonymous namespace
1338 
1339 //===----------------------------------------------------------------------===//
1340 // RewriteBitCastOfTruncI
1341 //===----------------------------------------------------------------------===//
1342 
1343 namespace {
1344 
1345 /// Helper struct to keep track of the provenance of a contiguous set of bits
1346 /// in a source vector.
1347 struct SourceElementRange {
1348  /// The index of the source vector element that contributes bits to *this.
1349  int64_t sourceElementIdx;
1350  /// The range of bits in the source vector element that contribute to *this.
1351  int64_t sourceBitBegin;
1352  int64_t sourceBitEnd;
1353 };
1354 
1355 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1356  /// Given the index of a SourceElementRange in the SourceElementRangeList,
1357  /// compute the amount of bits that need to be shifted to the left to get the
1358  /// bits in their final location. This shift amount is simply the sum of the
1359  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1360 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
1361  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1362  int64_t res = 0;
1363  for (int64_t i = 0; i < shuffleIdx; ++i)
1364  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1365  return res;
1366  }
1367 };
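// Worked example (not from the original source): for the row
// { 1: b@[3..5) }{ 2: b@[0..5) }{ 3: b@[0..1) } in the table below,
// computeLeftShiftAmount(0) == 0, computeLeftShiftAmount(1) == 2 and
// computeLeftShiftAmount(2) == 2 + 5 == 7, which matches the `lshl`
// annotations printed next to each entry.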
1368 
1369 /// Helper struct to enumerate the source elements and bit ranges that are
1370 /// involved in a bitcast operation.
1371 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1372 /// any 1-D vector shape and any source/target bitwidths.
1373 /// This creates and holds a mapping of the form:
1374 /// [dstVectorElementJ] ==
1375 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1376 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1377 /// [0] = {0, [0-8)}
1378 /// [1] = {0, [8-16)}
1379 /// [2] = {0, [16-24)}
1380 /// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
1381 ///   [0] = {0, [0-10)}, {1, [0-5)}
1382 ///   [1] = {1, [5-10)}, {2, [0-10)}
1383 struct BitCastBitsEnumerator {
1384  BitCastBitsEnumerator(VectorType sourceVectorType,
1385  VectorType targetVectorType);
1386 
1387  int64_t getMaxNumberOfEntries() {
1388  int64_t numVectors = 0;
1389  for (const auto &l : sourceElementRanges)
1390  numVectors = std::max(numVectors, (int64_t)l.size());
1391  return numVectors;
1392  }
1393 
1394  VectorType sourceVectorType;
1395  VectorType targetVectorType;
1396  SmallVector<SourceElementRangeList> sourceElementRanges;
1397 };
1398 
1399 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1400 /// advantage of high-level information to avoid leaving LLVM to scramble with
1401 /// peephole optimizations.
1402 /// BitCastBitsEnumerator encodes for each element of the target vector the
1403 /// provenance of the bits in the source vector. We can "transpose" this
1404 /// information to build a sequence of shuffles and bitwise ops that will
1405 /// produce the desired result.
1406 //
1407 /// Consider the following motivating example:
1408 /// ```
1409 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1410 /// ```
1411 //
1412 /// BitCastBitsEnumerator contains the following information:
1413 /// ```
1414 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1415 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1416 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1417 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1418 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1419 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1420 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1421 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1422 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1423 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1424 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1425 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1426 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1427 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1428 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1429 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1430 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1431 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1432 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1433 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1434 /// ```
1435 ///
1436 /// In the above, each row represents one target vector element and each
1437 /// column represents one bit contribution from a source vector element.
1438 /// The algorithm creates vector.shuffle operations (in this case there are 3
 1439 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1440 /// algorithm populates the bits as follows:
1441 /// ```
1442 /// src bits 0 ...
1443 /// 1st shuffle |xxxxx |xx |...
1444 /// 2nd shuffle | xxx| xxxxx |...
1445 /// 3rd shuffle | | x|...
1446 /// ```
1447 //
1448 /// The algorithm proceeds as follows:
1449 /// 1. for each vector.shuffle, collect the source vectors that participate in
1450 /// this shuffle. One source vector per target element of the resulting
1451 /// vector.shuffle. If there is no source element contributing bits for the
1452 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1453 /// 2 columns).
1454 /// 2. represent the bitrange in the source vector as a mask. If there is no
1455 /// source element contributing bits for the current vector.shuffle, take 0.
1456 /// 3. shift right by the proper amount to align the source bitrange at
1457 /// position 0. This is exactly the low end of the bitrange. For instance,
1458 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1459 /// shift right by 3 to get the bits contributed by the source element #1
1460 /// into position 0.
 1461 /// 4. shift left by the proper amount to align to the desired position in
1462 /// the result element vector. For instance, the contribution of the second
1463 /// source element for the first row needs to be shifted by `5` to form the
1464 /// first i8 result element.
1465 ///
1466 /// Eventually, we end up building the sequence
1467 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1468 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1469 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1470 struct BitCastRewriter {
1471  /// Helper metadata struct to hold the static quantities for the rewrite.
1472  struct Metadata {
1473  SmallVector<int64_t> shuffles;
1474  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1475  };
1476 
1477  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1478 
1479  /// Verify that general preconditions for the rewrite are met.
1480  LogicalResult commonPrecondition(PatternRewriter &rewriter,
1481  VectorType preconditionType, Operation *op);
1482 
1483  /// Precompute the metadata for the rewrite.
 1484  SmallVector<BitCastRewriter::Metadata>
 1485  precomputeMetadata(IntegerType shuffledElementType);
1486 
1487  /// Rewrite one step of the sequence:
1488  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1489  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1490  Value initialValue, Value runningResult,
1491  const BitCastRewriter::Metadata &metadata);
1492 
1493 private:
 1494  /// Underlying enumerator that encodes the provenance of the bits in each
1495  /// element of the result vector.
1496  BitCastBitsEnumerator enumerator;
1497 };
1498 
1499 } // namespace
1500 
1501 [[maybe_unused]] static raw_ostream &
1502 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1503  for (const auto &l : vec) {
1504  for (auto it : llvm::enumerate(l)) {
1505  os << "{ " << it.value().sourceElementIdx << ": b@["
1506  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1507  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1508  }
1509  os << "\n";
1510  }
1511  return os;
1512 }
1513 
1514 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1515  VectorType targetVectorType)
1516  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1517 
1518  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1519  "requires -D non-scalable vector type");
1520  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1521  "requires -D non-scalable vector type");
1522  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1523  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1524  LDBG() << "sourceVectorType: " << sourceVectorType;
1525 
1526  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1527  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1528  LDBG() << "targetVectorType: " << targetVectorType;
1529 
1530  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1531  (void)mostMinorSourceDim;
1532  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1533  "source and target bitwidths must match");
1534 
1535  // Prepopulate one source element range per target element.
1536  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1537  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1538  int64_t resultElement = resultBit / targetBitWidth;
1539  int64_t resultBitInElement = resultBit % targetBitWidth;
1540  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1541  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1542  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1543  targetBitWidth - resultBitInElement);
1544  sourceElementRanges[resultElement].push_back(
1545  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1546  resultBit += step;
1547  }
1548 }
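// Editorial trace (illustration only) of the loop above, instantiated for the
// `vector<1xi24> to vector<3xi8>` example from the class comment
// (sourceBitWidth = 24, targetBitWidth = 8, bitwidth = 24):
//   resultBit = 0 : resultElement = 0, sourceElementIdx = 0,
//                   sourceBitInElement = 0,  step = min(24, 8) = 8
//                   -> sourceElementRanges[0] gets {0, 0, 8}
//   resultBit = 8 : resultElement = 1, sourceBitInElement = 8,
//                   step = min(16, 8) = 8 -> sourceElementRanges[1] gets {0, 8, 16}
//   resultBit = 16: resultElement = 2, sourceBitInElement = 16,
//                   step = min(8, 8) = 8  -> sourceElementRanges[2] gets {0, 16, 24}
// which matches the [0-8), [8-16), [16-24) decomposition documented above.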
1549 
1550 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1551  VectorType targetVectorType)
1552  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1553  LDBG() << "\n" << enumerator.sourceElementRanges;
1554 }
1555 
1556 /// Verify that the precondition type meets the common preconditions for any
1557 /// conversion.
1558 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1559  VectorType preconditionType,
1560  Operation *op) {
1561  if (!preconditionType || preconditionType.isScalable())
1562  return rewriter.notifyMatchFailure(op, "scalable vector");
1563 
1564  // TODO: consider relaxing this restriction in the future if we find ways
1565  // to really work with subbyte elements across the MLIR/LLVM boundary.
1566  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1567  if (bitwidth % 8 != 0)
1568  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1569 
1570  return success();
1571 }
1572 
1573 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1574  VectorType preconditionType,
1575  Operation *op) {
1576  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1577  return rewriter.notifyMatchFailure(op, "types are not vector");
1578 
1579  if (!preconditionType || preconditionType.getRank() != 1)
1580  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1581 
1582  return commonConversionPrecondition(rewriter, preconditionType, op);
1583 }
1584 
1585 /// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1586 ///
1587 /// Alignment means that `subByteVecTy` can be packed into a vector of
1588 /// `containerTy` elements. More specifically:
1589 /// 1. The bit-width of `containerTy` is a multiple of the
1590 /// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1591 /// this multiple is 4.
1592 /// 2. The multiple from 1. above divides evenly the number of the (trailing)
1593 /// elements in `subByteVecTy`.
1594 ///
1595 /// EXAMPLE 1:
1596 /// `subByteVecTy = vector<2xi4>`, and
1597 /// `containerTy = i16`
1598 ///
1599 /// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1600 ///
1601 /// EXAMPLE 2:
1602 /// `subByteVecTy = vector<3xi4>`, and
1603 /// `containerTy = i16`
1604 ///
1605 /// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1606 ///
1607 /// EXAMPLE 3:
1608 /// `subByteVecTy = vector<3xi3>`, and
1609 /// `containerTy = i16`
1610 ///
1611 /// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1612 ///
1613 /// NOTE: This method assumes that common conversion preconditions are met. In
1614 /// particular, `containerTy` is assumed to be a
1615 /// multi-byte scalar type (e.g., i8, i16, i32).
1616 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1617  VectorType subByteVecTy,
1618  Type containerTy,
1619  Operation *op) {
1620  assert(containerTy.isIntOrFloat() &&
1621  "container element type is not a scalar");
1622 
1623  // TODO: This is validating the inputs rather than checking the conditions
1624  // documented above. Replace with an assert.
1625  if (!subByteVecTy)
1626  return rewriter.notifyMatchFailure(op, "not a vector!");
1627 
1628  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1629  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1630 
1631  // Enforced by the common pre-conditions.
1632  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1633 
 1634  // TODO: Add support for other widths (when/if needed)
1635  if (subByteBits != 2 && subByteBits != 4)
1636  return rewriter.notifyMatchFailure(
1637  op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1638 
1639  // Condition 1 ("per-element" alignment)
1640  if (containerBits % subByteBits != 0)
 1641  return rewriter.notifyMatchFailure(op, "unaligned element types");
1642 
1643  // Condition 2 ("full" alignment)
1644  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1645  return rewriter.notifyMatchFailure(
1646  op, "not possible to fit this sub-byte vector type into a vector of "
1647  "the given multi-byte type");
1648 
1649  return success();
1650 }
1651 
 1652 SmallVector<BitCastRewriter::Metadata>
 1653 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
 1654  SmallVector<BitCastRewriter::Metadata> result;
 1655  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1656  shuffleIdx < e; ++shuffleIdx) {
1657  SmallVector<int64_t> shuffles;
1658  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1659 
1660  // Create the attribute quantities for the shuffle / mask / shift ops.
1661  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1662  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1663  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1664  : 0;
1665  shuffles.push_back(sourceElement);
1666 
1667  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1668  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1669  : 0;
1670  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1671  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1672  : 0;
1673  IntegerAttr mask = IntegerAttr::get(
1674  shuffledElementType,
1675  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1676  bitLo, bitHi));
1677  masks.push_back(mask);
1678 
1679  int64_t shiftRight = bitLo;
1680  shiftRightAmounts.push_back(
1681  IntegerAttr::get(shuffledElementType, shiftRight));
1682 
1683  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1684  shiftLeftAmounts.push_back(
1685  IntegerAttr::get(shuffledElementType, shiftLeft));
1686  }
1687 
1688  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1689  }
1690  return result;
1691 }
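// Editorial illustration (not from the upstream sources): for the
// `vector<32xi5> to vector<20xi8>` example documented above, assuming an i8
// `shuffledElementType`, the first iteration (shuffleIdx = 0) records for the
// first two result rows:
//   row 0, entry { 0: b@[0..5) lshl: 0}:
//     shuffle index 0, mask 0b00011111, shiftRight 0, shiftLeft 0
//   row 1, entry { 1: b@[3..5) lshl: 0}:
//     shuffle index 1, mask 0b00011000, shiftRight 3, shiftLeft 0
// Rows that have no entry at a given shuffleIdx contribute shuffle index 0 and
// an all-zero mask, as handled by the guards above.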
1692 
1693 Value BitCastRewriter::genericRewriteStep(
1694  PatternRewriter &rewriter, Location loc, Value initialValue,
1695  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1696  // Create vector.shuffle from the metadata.
1697  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1698  initialValue, metadata.shuffles);
1699 
1700  // Intersect with the mask.
1701  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1702  auto constOp = arith::ConstantOp::create(
1703  rewriter, loc,
1704  DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1705  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1706 
1707  // Align right on 0.
1708  auto shiftRightConstantOp = arith::ConstantOp::create(
1709  rewriter, loc,
1710  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1711  Value shiftedRight =
1712  arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1713 
1714  // Shift bits left into their final position.
1715  auto shiftLeftConstantOp = arith::ConstantOp::create(
1716  rewriter, loc,
1717  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1718  Value shiftedLeft =
1719  arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1720 
1721  runningResult =
1722  runningResult
1723  ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1724  : shiftedLeft;
1725 
1726  return runningResult;
1727 }
1728 
1729 /// Bitcasts the aligned `subByteVec` vector to a vector of i8.
 1730 /// Here, aligned means that it satisfies alignedConversionPrecondition.
1731 ///
1732 /// Example:
1733 /// vector<16x16xi2> -> vector<16x4xi8>
1734 /// vector<16x16xi4> -> vector<16x8xi8>
 1735 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
 1736  Value subByteVec) {
1737  auto srcVecType = cast<VectorType>(subByteVec.getType());
1738  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1739  assert(8 % srcBitwidth == 0 &&
1740  "Unsupported sub-byte type (not a divisor of i8)");
1741  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1742  SmallVector<int64_t> vecShape(srcVecType.getShape());
1743  // Adjust last dimension of the vector, so the total size remains the same.
1744  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1745  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1746  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1747 }
1748 
1749 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1750 /// starting at the specified bit index.
1751 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1752 ///
1753 /// Example for a single element:
1754 /// Extract numBits=2 starting at bitIdx=2
1755 /// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1756 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1757 /// target = [. . . . ^ ^ . .]
1758 ///
1759 /// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1760 /// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1761 ///
1762 /// src = [01 01 11 10]
1763 /// shl = arith.shl(src, 4) -> [11 10 00 00]
1764 /// result = arith.shrsi(shl, 6) -> [11 11 11 11]
 1765 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
 1766  Location loc, Value src,
1767  int bitIdx, int numBits) {
1768  auto srcType = cast<VectorType>(src.getType());
1769  Value shl = src;
1770  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1771  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1772  "Invalid bitIdx range");
1773  if (bitsToShiftLeft != 0) {
1774  Value shiftLeftValues = arith::ConstantOp::create(
1775  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1776  shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1777  }
1778 
1779  int8_t bitsToShiftRight = 8 - numBits;
1780  Value shiftRightValues = arith::ConstantOp::create(
1781  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1782  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1783  return shr;
1784 }
1785 
1786 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1787 /// starting at the specified bit index.
1788 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1789 ///
1790 /// Example for a single element:
1791 /// Extract numBits=2 starting at bitIdx=2
1792 /// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1793 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1794 /// target = [. . . . ^ ^ . .]
1795 ///
1796 /// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1797 /// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1798 ///
1799 /// src = [01 01 10 10]
1800 /// mask = [00 00 00 11]
1801 /// shr = arith.shrui(src, 2) = [00 01 01 10]
1802 /// result = arith.andi(shr, mask) = [00 00 00 10]
1803 /// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1804 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1805 /// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1806 /// left when the index is 0.
 1807 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
 1808  Location loc, Value src,
1809  int bitIdx, int numBits) {
1810  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1811  "Invalid bitIdx range");
1812  auto srcType = cast<VectorType>(src.getType());
1813  int8_t bitsToShiftRight = bitIdx;
1814  Value shr = src;
1815  if (bitsToShiftRight != 0) {
1816  Value shiftRightValues = arith::ConstantOp::create(
1817  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1818  shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1819  }
1820  if (bitIdx + numBits == 8) {
1821  return shr;
1822  }
1823  uint8_t lowBitsMask = (1 << numBits) - 1;
1824  Value lowBitsMaskValues = arith::ConstantOp::create(
1825  rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1826  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1827 }
1828 
 1829 using ExtractNBitsFn =
 1830  std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1831 
1832 /// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1833 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
 1834 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
 1835  Value srcValue, const ExtractNBitsFn &extFn) {
1836  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1837  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1838  "Expected i4 type");
1839 
1840  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1841  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1842 
 1843  // 2. Extend i4 elements to i8 elements. Low i4 elements of each
 1844  // byte are placed in one vector and the high i4 elements in another vector.
1845  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1846  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1847 
1848  // 3. Interleave low and high i8 elements.
1849  return vector::InterleaveOp::create(rewriter, loc, low, high);
1850 }
1851 
1852 /// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1853 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
 1854 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
 1855  Value srcValue, const ExtractNBitsFn &extFn) {
1856  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1857  assert(srcVecType.getElementType().isSignlessInteger(2) &&
1858  "Expected i2 type");
1859 
 1860  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1861  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1862 
1863  // 2. Extract each i2 element
 1864  // Position 0 (bits 0-1)
1865  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1866  // Position 1 (bits 2-3)
1867  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1868  // Position 2 (bits 4-5)
1869  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1870  // Position 3 (bits 6-7)
1871  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1872 
1873  // 3. Interleave all 4 elements by first interleaving
1874  // even elements and then odd
1875  // vec0 = [0,0,0,0],...
1876  // vec1 = [1,1,1,1],...
1877  // vec2 = [2,2,2,2],...
1878  // vec3 = [3,3,3,3],...
1879  // 02 = [0,2,0,2,0,2,0,2],...
1880  // 13 = [1,3,1,3,1,3,1,3],...
1881  // 0213 = [0,1,2,3,...],...
1882  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1883  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1884  return vector::InterleaveOp::create(rewriter, loc, interleave02,
1885  interleave13);
1886 }
1887 
1888 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1889 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
 1890 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
 1891  Value srcValue) {
1892  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1893  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1894  "Expected i8 type");
1895 
1896  // 1. De-interleave low and high i8 elements.
1897  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1898 
1899  // 2. Zero out the upper side of each low i8 element.
1900  constexpr int8_t i8LowBitMask = 0x0F;
1901  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1902  Value zeroOutMask = arith::ConstantOp::create(
1903  rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1904  Value zeroOutLow = arith::AndIOp::create(
1905  rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1906 
1907  // 3. Move high i4 values to upper side of the byte.
1908  constexpr int8_t bitsToShift = 4;
1909  auto shiftValues = arith::ConstantOp::create(
1910  rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1911  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1912  shiftValues);
1913 
1914  // 4. Merge high and low i4 values.
1915  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1916 
1917  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1918  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1919  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1920 }
1921 
1922 namespace {
1923 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1924 /// advantage of high-level information to avoid leaving LLVM to scramble with
1925 /// peephole optimizations.
1926 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
 1927  using OpRewritePattern::OpRewritePattern;
 1928 
1929  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1930  PatternRewriter &rewriter) const override {
1931  // The source must be a trunc op.
1932  auto truncOp =
1933  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1934  if (!truncOp)
1935  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1936 
1937  // Set up the BitCastRewriter and verify the precondition.
1938  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1939  VectorType targetVectorType = bitCastOp.getResultVectorType();
1940  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1941  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1942  return failure();
1943 
1944  // Perform the rewrite.
1945  Value truncValue = truncOp.getIn();
1946  auto shuffledElementType =
1947  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1948  Value runningResult;
 1949  for (const BitCastRewriter::Metadata &metadata :
1950  bcr.precomputeMetadata(shuffledElementType)) {
1951  runningResult = bcr.genericRewriteStep(
1952  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1953  }
1954 
1955  // Finalize the rewrite.
1956  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1957  shuffledElementType.getIntOrFloatBitWidth();
1958  if (narrowing) {
1959  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1960  rewriter.replaceOp(bitCastOp, runningResult);
1961  } else {
1962  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1963  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1964  }
1965  } else {
1966  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1967  rewriter.replaceOp(bitCastOp, runningResult);
1968  } else {
1969  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1970  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1971  }
1972  }
1973 
1974  return success();
1975  }
1976 };
1977 } // namespace
1978 
1979 //===----------------------------------------------------------------------===//
1980 // RewriteExtOfBitCast
1981 //===----------------------------------------------------------------------===//
1982 
1983 namespace {
1984 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1985 /// take advantage of high-level information to avoid leaving LLVM to scramble
1986 /// with peephole optimizations.
1987 template <typename ExtOpType>
1988 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 1989  using OpRewritePattern<ExtOpType>::OpRewritePattern;
 1990 
1991  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1992  : OpRewritePattern<ExtOpType>(context, benefit) {}
1993 
1994  LogicalResult matchAndRewrite(ExtOpType extOp,
1995  PatternRewriter &rewriter) const override {
1996  // The source must be a bitcast op.
1997  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1998  if (!bitCastOp)
1999  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2000 
2001  // Set up the BitCastRewriter and verify the precondition.
2002  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2003  VectorType targetVectorType = bitCastOp.getResultVectorType();
2004  BitCastRewriter bcr(sourceVectorType, targetVectorType);
2005  if (failed(bcr.commonPrecondition(
2006  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2007  return failure();
2008 
2009  // Perform the rewrite.
2010  Value runningResult;
2011  Value sourceValue = bitCastOp.getSource();
2012  auto shuffledElementType =
2013  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2014  for (const BitCastRewriter::Metadata &metadata :
2015  bcr.precomputeMetadata(shuffledElementType)) {
2016  runningResult = bcr.genericRewriteStep(
2017  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2018  }
2019 
2020  // Finalize the rewrite.
2021  bool narrowing =
2022  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2023  shuffledElementType.getIntOrFloatBitWidth();
2024  if (narrowing) {
2025  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2026  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2027  } else {
2028  rewriter.replaceOpWithNewOp<ExtOpType>(
2029  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2030  }
2031 
2032  return success();
2033  }
2034 };
2035 
2036 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2037 /// bitwise ops that take advantage of high-level information to avoid leaving
2038 /// LLVM to scramble with peephole optimizations. Templated to choose between
2039 /// signed and unsigned conversions.
2040 ///
2041 /// EXAMPLE 1 (signed):
2042 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
 2043 /// is rewritten as:
2044 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2045 /// %1 = arith.shli %0, 4 : vector<4xi8>
2046 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2047 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2048 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2049 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2050 ///
2051 /// EXAMPLE 2 (fp):
2052 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
 2053 /// is rewritten as:
2054 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2055 /// %1 = arith.shli %0, 4 : vector<4xi8>
2056 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2057 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2058 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2059 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2060 ///
2061 /// EXAMPLE 3 (unsigned):
2062 /// arith.extui %in : vector<8xi4> to vector<8xi32>
2063 /// is rewritten as:
2064 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2065 /// %1 = arith.andi %0, 15 : vector<4xi8>
2066 /// %2 = arith.shrui %0, 4 : vector<4xi8>
2067 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2068 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2069 ///
2070 template <typename ConversionOpType, bool isSigned>
2071 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 2072  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 2073 
2074  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2075  PatternRewriter &rewriter) const override {
2076  // Verify the preconditions.
2077  Value srcValue = conversionOp.getIn();
2078  VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2079  VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2080 
2081  if (failed(
2082  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2083  return failure();
2084 
2085  // Check general alignment preconditions.
 2086  if (failed(alignedConversionPrecondition(
 2087  rewriter, srcVecType,
2088  /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2089  return failure();
2090 
2091  // Perform the rewrite.
2092  Location loc = conversionOp.getLoc();
2093  const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
 2094  : extractNBitsPerByteAndExtendToI8;
 2095  Value subByteExt;
2096  switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2097  case 2:
2098  subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2099  break;
2100  case 4:
2101  subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2102  break;
2103  default:
2104  return failure();
2105  }
2106 
2107  // Finalize the rewrite.
2108  rewriter.replaceOpWithNewOp<ConversionOpType>(
2109  conversionOp, conversionOp.getType(), subByteExt);
2110  return success();
2111  }
2112 };
2113 
2114 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2115 /// bitwise ops that take advantage of high-level information to avoid leaving
2116 /// LLVM to scramble with peephole optimizations.
2117 ///
2118 /// For example:
2119 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
2120 ///
 2121 /// is rewritten as:
2122 ///
2123 /// %cst = arith.constant dense<15> : vector<4xi8>
2124 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
2125 /// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
2126 /// %2 = arith.andi %0, %cst : vector<4xi8>
2127 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2128 /// %4 = arith.ori %2, %3 : vector<4xi8>
2129 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2130 ///
2131 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 2132  using OpRewritePattern::OpRewritePattern;
 2133 
2134  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2135  PatternRewriter &rewriter) const override {
2136  // Verify the preconditions.
2137  Value srcValue = truncOp.getIn();
2138  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2139  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2140  if (!srcVecType || !dstVecType)
2141  return failure();
2142 
2143  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2144  return failure();
2145 
2146  // TODO: Add support for truncating to i2.
2147  if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2148  return failure();
2149 
2150  // Check general alignment preconditions. We invert the src/dst type order
2151  // to reuse the existing precondition logic.
 2152  if (failed(alignedConversionPrecondition(
 2153  rewriter, dstVecType,
2154  /*containerTy=*/rewriter.getI8Type(), truncOp)))
2155  return failure();
2156 
2157  // Create a new iX -> i8 truncation op.
2158  Location loc = truncOp.getLoc();
2159  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2160  Value i8TruncVal =
2161  arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2162 
2163  // Rewrite the i8 -> i4 truncation part.
2164  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2165 
2166  // Finalize the rewrite.
2167  rewriter.replaceOp(truncOp, subByteTrunc);
2168  return success();
2169  }
2170 };
2171 
2172 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
2173 /// perform the transpose on wider (byte) element types.
2174 ///
2175 /// EXAMPLE:
2176 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2177 ///
2178 /// is rewritten as:
2179 ///
2180 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2181 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2182 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2183 ///
2184 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 2185  using OpRewritePattern::OpRewritePattern;
 2186 
2187  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2188  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2189 
2190  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2191  PatternRewriter &rewriter) const override {
2192  // Precondition: sub-byte integer transpose.
2193  constexpr unsigned minNativeBitwidth = 8;
2194  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2195  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2196  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2197  return rewriter.notifyMatchFailure(transposeOp,
2198  "not a sub-byte transpose");
2199  }
2200 
2201  // Perform the rewrite.
2202  Location loc = transposeOp.getLoc();
2203  // Signed/unsigned interpretation shouldn't matter here as we are just
2204  // transposing the elements and truncating them back to the original size.
2205  // TODO: Use unsigned extension (more efficient) when emulation or backend
2206  // support is available.
2207  auto srcNativeVecType = srcSubByteVecType.cloneWith(
2208  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2209  Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2210  transposeOp.getVector());
2211  Value newTranspose = vector::TransposeOp::create(
2212  rewriter, loc, extOp, transposeOp.getPermutation());
2213  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2214  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2215  newTranspose);
2216  return success();
2217  }
2218 };
2219 
2220 } // namespace
2221 
2222 //===----------------------------------------------------------------------===//
2223 // Public Interface Definition
2224 //===----------------------------------------------------------------------===//
2225 
2226 // The emulated type is inferred from the converted memref type.
2227 void vector::populateVectorNarrowTypeEmulationPatterns(
2228  const arith::NarrowTypeEmulationConverter &typeConverter,
2229  RewritePatternSet &patterns, bool disableAtomicRMW) {
2230 
2231  // Populate `vector.*` conversion patterns.
2232  // TODO: #119553 support atomicity
2233  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2234  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2235  typeConverter, patterns.getContext());
2236 
 2237  // Populate `vector.*` store conversion patterns. The caller can choose
 2238  // to avoid emitting atomic operations and instead lower stores to a
 2239  // read-modify-write sequence if it is known there is no thread contention.
2240  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2241 }
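// Editorial sketch of a typical driver for these conversion patterns. It
// loosely mirrors the in-tree narrow-type emulation test pass; the 8-bit
// container width, the memref population calls, and the elided legality setup
// are assumptions, not requirements of this file.
// ```
// void runNarrowTypeEmulation(Operation *op, MLIRContext *ctx) {
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
//
//   ConversionTarget target(*ctx);
//   // Mark the ops that must be converted as dynamically illegal here.
//
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
//                                                     /*disableAtomicRMW=*/false);
//   (void)applyPartialConversion(op, target, std::move(patterns));
// }
// ```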
2242 
2243 void vector::populateVectorNarrowTypeRewritePatterns(
 2244  RewritePatternSet &patterns, PatternBenefit benefit) {
 2245  // TODO: Document what the emulated type is.
2246  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2247  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2248  benefit);
2249 
2250  // Patterns for aligned cases. We set higher priority as they are expected to
2251  // generate better performance for aligned cases.
2252  // The container type is always i8.
2253  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2254  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2255  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2256  benefit.getBenefit() + 1);
2257  // The container type is always i8.
2258  patterns
2259  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2260  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2261  patterns.getContext(), benefit.getBenefit() + 1);
2262 }
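// Editorial sketch: unlike the conversion patterns above, these are ordinary
// rewrite patterns, so a greedy driver suffices. The wrapper function and the
// benefit value below are assumptions for illustration.
// ```
// void runNarrowTypeRewrites(Operation *op, MLIRContext *ctx) {
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeRewritePatterns(patterns, PatternBenefit(1));
//   // Greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h.
//   (void)applyPatternsGreedily(op, std::move(patterns));
// }
// ```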
2263 
2264 // The container type is always i8.
2265 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
 2266  RewritePatternSet &patterns, PatternBenefit benefit) {
 2267  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2268 }