//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`.

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`, to further reduce
// the possibility of typo bugs or things getting out of sync.
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}
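
// For example (illustrative, derived from the cases above):
//   overheadTypeFunctionSuffix(OverheadType::kU32)   == "32"
//   overheadTypeFunctionSuffix(OverheadType::kIndex) == "0"
// These suffixes let a single runtime entry point be "overloaded" per
// overhead type without relying on C++ name mangling.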

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return arith::IndexCastOp::create(builder, loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

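// A minimal sketch of what genCast emits for an int-to-index cast (names
// are illustrative): for %v : i32 and dstTp = index, the result is
//   %0 = arith.index_cast %v : i32 to index
// All other numeric conversions are delegated to convertScalarToDtype,
// with unsigned source integers requesting zero-extending semantics.
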
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return tensor::FromElementsOp::create(builder, loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = memref::LoadOp::create(builder, loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = arith::ExtUIOp::create(builder, loc, builder.getI64Type(), load);
    load =
        arith::IndexCastOp::create(builder, loc, builder.getIndexType(), load);
  }
  return load;
}

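// A minimal sketch of the IR genIndexLoad produces for a narrow position
// buffer (illustrative names, assuming %mem : memref<?xi32>):
//   %0 = memref.load %mem[%i] : memref<?xi32>
//   %1 = arith.extui %0 : i32 to i64
//   %2 = arith.index_cast %1 : i64 to index
// Zero-extending to i64 first ensures the narrow value is treated as
// unsigned before the final cast to index.
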
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::UNE, v,
                                 zero);
  if (tp.isIntOrIndex())
    return arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, v,
                                 zero);
  if (isa<ComplexType>(tp))
    return complex::NotEqualOp::create(builder, loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

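// For example (illustrative), genIsNonzero on %v : f64 emits
//   %0 = arith.cmpf une, %v, %cst_zero : f64
// where the unordered UNE predicate also classifies NaN as "nonzero".
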
void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = arith::MulIOp::create(builder, loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamically sized dimension among the
      // dimensions expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            arith::DivUIOp::create(builder, loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

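// A worked example of the expansion case above (values illustrative):
// expanding srcShape = [%d] into staticDstShape = <2x?x2> with
// reassociation [[0, 1, 2]] takes the product of the static sizes
// (2 * 2 = 4) and emits %d / 4 for the dynamic dimension, yielding
// dstShape = [2, %d / 4, 2].
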
void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over the reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in the dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = arith::MulIOp::create(builder, loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over the dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = arith::DivUIOp::create(builder, loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul =
            arith::MulIOp::create(builder, loc, srcCvs[j], linear);
        val = val ? arith::AddIOp::create(builder, loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = arith::DivUIOp::create(builder, loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = arith::RemUIOp::create(builder, loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

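// A worked example of the collapse case above (values illustrative):
// collapsing coordinates (%i, %j) with srcSizes = [3, 4] and reassociation
// [[0, 1]] first forms linear = 3 * 4 = 12, then divides one size out per
// iteration (linear = 4, then 1), accumulating %i * 4 + %j. Expansion runs
// the same strides in reverse, recovering (%i, %j) from a linearized
// coordinate with division and remainder.
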
FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = func::FuncOp::create(
        moduleBuilder, module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

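// A minimal usage sketch (the function name "foo" is illustrative):
//   FlatSymbolRefAttr fn =
//       getFunc(module, "foo", {builder.getIndexType()}, {v},
//               EmitCInterface::On);
// The first call declares a private func.func @foo in the module (tagged
// for C-interface emission); subsequent calls with the same name return
// the existing symbol without redeclaring it.
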
func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return func::CallOp::create(builder, loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return memref::AllocaOp::create(builder, loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return memref::AllocaOp::create(builder, loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return memref::AllocaOp::create(builder, loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    memref::StoreOp::create(builder, loc, values[i], buffer, idx);
  }
  return buffer;
}

Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = memref::AllocOp::create(builder, loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  linalg::FillOp::create(builder, loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  memref::DeallocOp::create(builder, loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`), as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sort the sparse element attributes based on coordinates.
  llvm::sort(elems, [order](const ElementAttr &lhs, const ElementAttr &rhs) {
    if (std::addressof(lhs) == std::addressof(rhs))
      return false;

    auto lhsCoords = llvm::map_to_vector(
        lhs.first, [](IntegerAttr i) { return i.getInt(); });
    auto rhsCoords = llvm::map_to_vector(
        rhs.first, [](IntegerAttr i) { return i.getInt(); });

    SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
    SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
    // Sort the elements based on the level coordinates.
    for (Level l = 0; l < order.getNumResults(); l++) {
      if (lhsLvlCrds[l] == rhsLvlCrds[l])
        continue;
      return lhsLvlCrds[l] < rhsLvlCrds[l];
    }
    llvm_unreachable("no equal coordinate in sparse element attr");
  });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(arith::ConstantIndexOp::create(builder, loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = complex::ConstantOp::create(builder, loc, attr.getElementType(),
                                        valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = arith::ConstantOp::create(builder, loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

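// For example (illustrative), given the 2-D constant
//   sparse<[[1, 0], [0, 1]], [2.0, 1.0]> : tensor<2x2xf64>
// and the identity order, the elements are visited in coordinate order,
// so the callback first receives coordinates (0, 1) with value 1.0, and
// then coordinates (1, 0) with value 2.0.
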
SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = memref::LoadOp::create(builder, loc, mem,
                                     constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = arith::AddIOp::create(builder, loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void mlir::sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                                   ValueRange vs, size_t offsetIdx,
                                   Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? arith::AddIOp::create(builder, loc, v.value(), offsetVal)
            : v.value();
    memref::StoreOp::create(builder, loc, w, mem,
                            constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType> sparse_tensor::genToMemref(OpBuilder &builder,
                                                      Location loc,
                                                      Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return cast<TypedValue<BaseMemRefType>>(
      bufferization::ToBufferOp::create(builder, loc, mTp, tensor).getResult());
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return ToSliceOffsetOp::create(builder, loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return ToSliceStrideOp::create(builder, loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim-shape values into dim-size values, just in case
    // subsequent clients need the values (DCE will remove them if unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = memref::LoadOp::create(
            builder, loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

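// For example (illustrative), for stt with dimension shape ?x100, the
// dim-shapes buffer handed to the reader is [0, 100] (zero marking the
// dynamic size); the reader verifies the static extent 100, and the call
// to getSparseTensorReaderDimSizes then supplies the actual extent of
// dimension 0.
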
Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical, as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank-changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //   (1) l = d
    //   (2) l = d / c
    //   (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes:
    //   (1) l = d     : size(d)
    //   (2) l = d / c : size(d) / c
    //   (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = arith::DivUIOp::create(builder, loc, lvlSz,
                                       constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //   (1) d = l
    //   (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on the lhs, symbol/constant on the rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return the buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}
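
// A worked example of the non-identity path above (illustrative 2x2
// blocking): for dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 2,
// d0 mod 2, d1 mod 2) and lvlToDim = (l0, l1, l2, l3) ->
// (l0 * 2 + l2, l1 * 2 + l3), the generated level sizes are
// [size(d0) / 2, size(d1) / 2, 2, 2], and the dim2lvl/lvl2dim buffers hold
// the encodeDim/encodeLvl encodings of these expressions for the runtime
// map to interpret.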