//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

150 
152  Type dstTp) {
153  const Type srcTp = value.getType();
154  if (srcTp == dstTp)
155  return value;
156 
157  // int <=> index
158  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
159  return builder.create<arith::IndexCastOp>(loc, dstTp, value);
160 
161  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
162  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
163  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
164 }
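
// Example (illustrative sketch; `builder`, `loc`, and an i32 value `v32`
// are assumed to be in scope): casting to index lowers to a single
// arith.index_cast, while other scalar conversions are delegated to
// convertScalarToDtype:
//
//   Value idx = genCast(builder, loc, v32, builder.getIndexType());
//   // emits: %idx = arith.index_cast %v32 : i32 to index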

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  Value s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}
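
// Example (illustrative sketch, assuming a position buffer `posMem` of
// type memref<?xi32> and an index `i` in scope): narrow overhead storage
// is widened to index in two steps, first zero-extending to i64, then
// casting to index:
//
//   Value pos = genIndexLoad(builder, loc, posMem, i);
//   // emits: memref.load, arith.extui (i32 to i64), arith.index_cast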

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (isa<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}
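
// Example (illustrative sketch): for an f64 value the comparison uses the
// UNE (unordered-or-not-equal) predicate, so NaN also counts as nonzero:
//
//   Value cond = genIsNonzero(builder, loc, f64Val);
//   // emits: %cond = arith.cmpf une, %f64Val, %cst_zero : f64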

void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamic sized dimension among dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}
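
// Worked example (illustrative only): expanding srcShape [%d] with
// staticDstShape <2x?x2> and reassociation [[0, 1, 2]] pushes the values
// (2, %d divui 4, 2) onto dstShape, since the product of the static
// expanded sizes is 2 * 2 = 4 and the single dynamic dimension must
// account for the remainder.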

void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}
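
// Worked example (illustrative only): collapsing coordinates (i, j) with
// sizes (m, n) into one dimension computes the linearized coordinate
// i * n + j; the expansion path inverts this with one divui/remui pair
// per expanded dimension.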

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}
359 
361  unsigned sz, Type tp, bool staticShape) {
362  if (staticShape) {
363  auto memTp = MemRefType::get({sz}, tp);
364  return builder.create<memref::AllocaOp>(loc, memTp);
365  }
366  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
367 }

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}
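
// Example (illustrative sketch, assuming index-typed values d0..d2 are in
// scope): materializing a small buffer on the stack, e.g. to pass
// dimension sizes to a runtime call:
//
//   SmallVector<Value> dims = {d0, d1, d2};
//   Value buf = allocaBuffer(builder, loc, dims);
//   // emits one memref.alloca of memref<?xindex> (sized by a constant 3)
//   // plus three memref.store ops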

Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sort the sparse element attributes based on coordinates.
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the elements based on the lvl coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}
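
// Example (illustrative sketch): the callback receives index-typed
// constant coordinates and a constant value per stored element, visited
// in the order induced by `order` (identity if null):
//
//   foreachInSparseConstant(builder, loc, attr, AffineMap(),
//                           [&](ArrayRef<Value> cvs, Value v) {
//                             // e.g., insert (cvs, v) into a new tensor
//                           });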

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}
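
// Example (illustrative sketch): loadAll/storeAll form a round trip over
// a rank-1 buffer; the optional offset arguments let a caller bias one
// entry, e.g. adding a slice offset to the first coordinate while copying:
//
//   SmallVector<Value> vals = loadAll(builder, loc, n, src, 0, off);
//   storeAll(builder, loc, dst, vals);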

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
                                    Value tensor, Level lvl) {
  const auto srcTp = getSparseTensorType(tensor);
  const Type posTp = srcTp.getPosType();
  const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
  return builder.create<ToPositionsOp>(loc, memTp, tensor,
                                       builder.getIndexAttr(lvl));
}

Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                      Value tensor, Level lvl) {
  const auto srcTp = getSparseTensorType(tensor);
  const Type crdTp = srcTp.getCrdType();
  const Type memTp =
      get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
  return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
                                         builder.getIndexAttr(lvl));
}

Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
                                            Value tensor) {
  const auto srcTp = getSparseTensorType(tensor);
  const Type crdTp = srcTp.getCrdType();
  const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
}

Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                 Value tensor) {
  RankedTensorType srcTp = getRankedTensorType(tensor);
  Type valTp = get1DMemRefType(srcTp.getElementType(),
                               /*withLayout=*/false);
  return builder.create<ToValuesOp>(loc, valTp, tensor);
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                   Value tensor) {
  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}
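
// Usage note (illustrative sketch): a conversion pass typically calls
// genReader first and then feeds its outputs into genMapBuffers below:
//
//   Value reader = genReader(builder, loc, stt, srcPtr, dimSizesValues,
//                            dimSizesBuffer);
//   // dimSizesValues[d] now holds a constant for each static dimension
//   // and a memref.load from the reader's size buffer for a dynamic one.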

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //    (1) l = d
    //    (2) l = d / c
    //    (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //    (1) l = d     : size(d)
    //    (2) l = d / c : size(d) / c
    //    (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //    (1) d = l
    //    (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on lhs, symbol/constant on rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}
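
// Worked example (illustrative only): for a block-sparse encoding with
// dimToLvl = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4), the dim2lvl loop
// above produces encodeDim(0, 0, 0) for level 0, encodeDim(1, 4, 0) for
// level 1 (with level size size(d1) / 4), and encodeDim(1, 0, 4) for
// level 2 (with constant level size 4).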