Is there any way to do transport-layer serialization after broadcast?

Hi, all! My overall task is to provide custom public-key authenticated encryption (with the TweetNaCl library) on top of Phoenix websockets. After some research I've found that the best way is to implement a custom Phoenix.Transports.Serializer. But the main problem now is how to encrypt broadcast messages: I need to encrypt each message separately for every client, using that client's own public key. Maybe I've chosen the wrong layer for my task? Any suggestions? Thanks!

After a number of tries I've found a solution. I think it is really a workaround rather than a proper solution, because it depends on the current Phoenix.Socket implementation (v1.4). But it works. The main idea is to override the Phoenix.Socket behaviour like this:

defmodule ProjectWeb.UserSocket do
  @moduledoc """
  User socket that decrypts every incoming frame and encrypts every outgoing
  binary frame with the key material the client supplied on connect.

  Instead of `use Phoenix.Socket`, this module wires up the socket by hand
  (import, behaviours, `@before_compile`, channel attribute) so that the
  transport callbacks `handle_in/2` and `handle_info/2` can wrap the
  `Phoenix.Socket.__in__/2` and `Phoenix.Socket.__info__/2` internals.
  Those double-underscore functions are private Phoenix API; this module
  was written against Phoenix 1.4 and must be re-checked on upgrade.
  """

  alias Phoenix.Socket
  alias Absinthe.Phoenix.Socket, as: AbsintheSocket
  import Socket
  @behaviour Socket
  @before_compile Socket
  Module.register_attribute(__MODULE__, :phoenix_channels, accumulate: true)
  @behaviour Socket.Transport

  use AbsintheSocket, schema: ProjectWeb.Schema

  require Logger

  # Proxy the transport callbacks we do not customise to Phoenix.Socket.
  @impl Socket.Transport
  def child_spec(opts), do: Socket.__child_spec__(__MODULE__, opts)

  @impl Socket.Transport
  def terminate(reason, state), do: Socket.__terminate__(reason, state)

  @impl Socket.Transport
  def init(state), do: Socket.__init__(state)

  @impl Socket.Transport
  def connect(map), do: Socket.__connect__(__MODULE__, map, :info)

  # Incoming frame from the client: decrypt, hand the plaintext to Phoenix,
  # then encrypt whatever Phoenix wants to send back. A frame that cannot be
  # decrypted closes the connection.
  @impl Socket.Transport
  def handle_in({payload, opts}, {state, %{assigns: %{client_key: key}} = socket}) do
    case Crypto.decrypt(:web_api, payload, key) do
      {:ok, decrypted} ->
        {decrypted, opts}
        |> Socket.__in__({state, socket})
        |> encrypt()

      _error ->
        {:stop, {:shutdown, :closed}, {state, socket}}
    end
  end

  # Outgoing message to the client (pushes, broadcasts, replies): let Phoenix
  # build the frame, then encrypt it for this particular client.
  @impl Socket.Transport
  def handle_info(message, state) do
    message
    |> Socket.__info__(state)
    |> encrypt()
  end

  # On connect we grab the client's public key and salt from the params.
  # Both must be Base64-encoded binaries; anything else rejects the connection.
  @impl Socket
  def connect(%{"key" => key, "salt" => salt}, socket) when is_binary(key) and is_binary(salt) do
    # Base.decode64/1 returns bare `:error` on invalid input, never {:error, reason}.
    with {:ok, decoded_key} <- Base.decode64(key),
         {:ok, decoded_salt} <- Base.decode64(salt) do
      socket = AbsintheSocket.put_options(socket, context: %{current_account: %{id: key}})
      {:ok, Socket.assign(socket, :client_key, {decoded_salt, decoded_key})}
    else
      :error ->
        Logger.error("UserSocket connect rejected: key or salt is not valid Base64")
        :error
    end
  end

  # Missing or non-binary key/salt: refuse the connection instead of crashing
  # with a FunctionClauseError.
  def connect(_params, _socket), do: :error

  # Socket id derived from the account id stored in the Absinthe context
  # by connect/2 (the Base64-encoded client public key).
  @impl Socket
  def id(%{assigns: %{absinthe: %{opts: opts}}}) do
    user_id =
      opts
      |> Keyword.get(:context)
      |> Map.get(:current_account)
      |> Map.get(:id)

    "users_socket:#{user_id}"
  end

  # Encrypts the binary payload of a push or reply with the client's key.
  # Any other return shape (:ok, :stop, non-binary frames) passes through
  # untouched. If encryption fails the connection is stopped: the plaintext
  # must never be sent as a fallback.
  defp encrypt({:push, {:binary, payload}, {_, %{assigns: %{client_key: key}}} = state}) do
    case Crypto.encrypt(:web_api, payload, key) do
      {:ok, encrypted} -> {:push, {:binary, encrypted}, state}
      _error -> encryption_failed(state)
    end
  end

  defp encrypt({:reply, status, {:binary, payload}, {_, %{assigns: %{client_key: key}}} = state}) do
    case Crypto.encrypt(:web_api, payload, key) do
      {:ok, encrypted} -> {:reply, status, {:binary, encrypted}, state}
      _error -> encryption_failed(state)
    end
  end

  defp encrypt(result), do: result

  # Shared failure path for encrypt/1: log and close the socket.
  defp encryption_failed(state) do
    Logger.error("UserSocket: failed to encrypt outgoing frame, closing connection")
    {:stop, {:shutdown, :encryption_failed}, state}
  end
end

I also use Msgpax to serialize messages to binary form:

defmodule ProjectWeb.MsgPackSerializer do
  @moduledoc false
  # Serializes Phoenix socket messages as a five-element MessagePack array:
  # [join_ref, ref, topic, event, payload].
  #
  # NOTE(review): Phoenix 1.4 appears to name this behaviour
  # Phoenix.Socket.Serializer; confirm against the Phoenix version in use.
  @behaviour Phoenix.Transports.Serializer
  alias Phoenix.Socket.{Broadcast, Message, Reply}

  # Encodes a broadcast for the fastlane path. Broadcasts carry no refs,
  # so the first two slots are nil. Raises if the payload cannot be packed.
  def fastlane!(%Broadcast{} = msg) do
    push_frame([nil, nil, msg.topic, msg.event, msg.payload])
  end

  # Encodes a reply as a "phx_reply" event whose payload wraps the status
  # and the response. Raises if the data cannot be packed.
  def encode!(%Reply{} = reply) do
    push_frame([
      reply.join_ref,
      reply.ref,
      reply.topic,
      "phx_reply",
      %{status: reply.status, response: reply.payload}
    ])
  end

  # Encodes a regular message. Raises if the data cannot be packed.
  def encode!(%Message{} = msg) do
    push_frame([msg.join_ref, msg.ref, msg.topic, msg.event, msg.payload])
  end

  # Decodes a binary frame into a %Message{}. As the bang name promises, it
  # raises on malformed MessagePack (Msgpax.UnpackError) and on anything that
  # is not a five-element array (MatchError) instead of returning the bad
  # value to the caller.
  def decode!(message, _opts) do
    [join_ref, ref, topic, event, payload] = Msgpax.unpack!(message)

    %Message{
      topic: topic,
      event: event,
      payload: payload,
      ref: ref,
      join_ref: join_ref
    }
  end

  # Packs `data` and wraps it in the binary push tuple shared by all
  # encoders. Msgpax.pack!/1 returns iodata, so it is flattened here.
  defp push_frame(data) do
    encoded =
      data
      |> Msgpax.pack!()
      |> IO.iodata_to_binary()

    {:socket_push, :binary, encoded}
  end
end

We also need to override the standard transport in the Phoenix client library:

// project/phoenix.js

import { box } from 'tweetnacl';
import { decodeBase64 } from 'tweetnacl/utils';
import MsgPack from '@msgpack/msgpack';

/**
 * Flattens a Phoenix message object into the positional wire array
 * [join_ref, ref, topic, event, payload] and hands it to `callback`.
 * Returns whatever `callback` returns.
 */
export function encode(message, callback) {
  const { join_ref: joinRef, ref, topic, event, payload } = message;
  const frame = [joinRef, ref, topic, event, payload];
  return callback(frame);
}

/**
 * Rebuilds a Phoenix message object from the positional wire array
 * [join_ref, ref, topic, event, payload] and passes it to `callback`.
 * Anything that is not a five-element array is logged and ignored
 * (returns undefined without invoking `callback`).
 */
export function decode(input, callback) {
  const wellFormed = Array.isArray(input) && input.length === 5;

  if (wellFormed) {
    const [joinRef, ref, topic, event, payload] = input;
    return callback({ join_ref: joinRef, ref, topic, event, payload });
  }

  // eslint-disable-next-line no-console
  console.warn('invalid payload: ', input);
  return undefined;
}

/**
 * Builds a WebSocket-compatible transport class for the Phoenix client that
 * MsgPack-encodes and NaCl-box-encrypts every outgoing frame and reverses
 * both steps for every incoming frame.
 *
 * @param {string} salt       Base64-encoded nonce passed to `box` / `box.open`.
 * @param {string} publicKey  Base64-encoded public key of the peer (server).
 * @param {Uint8Array} secretKey  This client's secret key.
 * @returns {Function} Transport class; Phoenix instantiates it with the endpoint URL.
 *
 * NOTE(review): the same nonce is used for every frame in both directions.
 * NaCl box requires a unique nonce per message for a given key pair, so this
 * should be revisited together with the server-side protocol.
 */
export function createTransport(salt, publicKey, secretKey) {
  const nonce = decodeBase64(salt);
  const peerKey = decodeBase64(publicKey);

  return class Transport {
    constructor(endpoint) {
      this.ws = new WebSocket(endpoint);
      // Frames are encrypted binary blobs, so always receive ArrayBuffers.
      this.ws.binaryType = 'arraybuffer';
    }

    get onmessage() {
      return this.messageHandler;
    }

    // Wraps the handler installed by Phoenix so that it receives the
    // decrypted, MsgPack-decoded frame in `event.data`.
    set onmessage(callback) {
      this.messageHandler = event => {
        const data = new Uint8Array(event.data);
        const decrypted = box.open(data, nonce, peerKey, secretKey);
        if (!decrypted) {
          // box.open signals an authentication failure with a falsy result;
          // decoding it would throw inside the socket's message handler.
          // eslint-disable-next-line no-console
          console.warn('dropping frame: decryption failed');
          return undefined;
        }
        return callback({ data: MsgPack.decode(decrypted) });
      };
      this.ws.onmessage = this.messageHandler;
    }

    get onopen() {
      return this.ws.onopen;
    }

    set onopen(value) {
      this.ws.onopen = value;
    }

    get onerror() {
      return this.ws.onerror;
    }

    set onerror(value) {
      this.ws.onerror = value;
    }

    get onclose() {
      return this.ws.onclose;
    }

    set onclose(value) {
      this.ws.onclose = value;
    }

    get readyState() {
      return this.ws.readyState;
    }

    get binaryType() {
      return this.ws.binaryType;
    }

    // The binary type is fixed in the constructor; assignments are ignored.
    set binaryType(_type) {
      return;
    }

    get bufferedAmount() {
      return this.ws.bufferedAmount;
    }

    get extensions() {
      return this.ws.extensions;
    }

    get protocol() {
      return this.ws.protocol;
    }

    get url() {
      return this.ws.url;
    }

    close(code, reason) {
      return this.ws.close(code, reason);
    }

    // MsgPack-encode, then encrypt, then send over the underlying socket.
    send(data) {
      const encoded = MsgPack.encode(data);
      const encrypted = box(encoded, nonce, peerKey, secretKey);
      return this.ws.send(encrypted);
    }
  };
}

And finally, create a client:

// project/client.js

import ApolloClient from 'apollo-client';
import { InMemoryCache as Cache } from 'apollo-cache-inmemory';
import * as AbsintheSocket from '@absinthe/socket';
import { createAbsintheSocketLink } from '@absinthe/socket-apollo-link';
import { Socket as PhoenixSocket } from 'phoenix';
import { createTransport, encode, decode } from 'project/phoenix';
import { encodeBase64 } from 'tweetnacl/util';
import { box } from 'tweetnacl';

import Config from 'project/config';

// Module-level cache: the Apollo client is built once and then reused.
let client;

/**
 * Returns the shared Apollo client, creating it on the first call.
 *
 * On creation a fresh NaCl key pair is generated; the public key is sent to
 * the server as the `key` connect param (Base64) together with `serverSalt`
 * as `salt`, and the secret key is kept inside the encrypting transport.
 * Later calls return the cached client and ignore their arguments.
 */
export function privateClient(serverSalt, serverKey) {
  if (!client) {
    const keyPair = box.keyPair();
    const connectParams = {
      key: encodeBase64(keyPair.publicKey),
      salt: serverSalt,
    };
    const socketOptions = {
      transport: createTransport(serverSalt, serverKey, keyPair.secretKey),
      encode,
      decode,
      params: connectParams,
    };
    const phoenixSocket = new PhoenixSocket(Config.api.SOCKET_URL, socketOptions);
    const absintheSocket = AbsintheSocket.create(phoenixSocket);

    client = new ApolloClient({
      link: createAbsintheSocketLink(absintheSocket),
      cache: new Cache(),
    });
  }

  return client;
}